Skip to content

Commit

Permalink
Refactoring + Improving Tutorials
Browse files Browse the repository at this point in the history
  • Loading branch information
gbg141 committed Apr 30, 2024
1 parent 5ca5120 commit 7063465
Show file tree
Hide file tree
Showing 22 changed files with 450 additions and 583 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
transform_type: 'lifting'
transform_name: "HypergraphKNNLifting"
k_value: 2
k_value: 3
loop: True
feature_lifting: ProjectionSum
File renamed without changes.
File renamed without changes.
File renamed without changes.
8 changes: 5 additions & 3 deletions modules/io/load/loaders.py → modules/data/load/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import torch_geometric
from omegaconf import DictConfig

from modules.io.load.base import AbstractLoader
from modules.io.utils.utils import (
from modules.data.load.base import AbstractLoader
from modules.data.utils.custom_dataset import CustomDataset
from modules.data.utils.utils import (
load_cell_complex_dataset,
load_hypergraph_pickle_dataset,
load_manual_graph,
Expand Down Expand Up @@ -102,7 +103,8 @@ def load(self) -> torch_geometric.data.Dataset:
dataset = datasets[0] + datasets[1] + datasets[2]

elif self.parameters.data_name in ["manual"]:
dataset = load_manual_graph()
data = load_manual_graph()
dataset = CustomDataset([data], self.data_dir)

else:
raise NotImplementedError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch_geometric

from modules.io.utils.utils import ensure_serializable, make_hash
from modules.data.utils.utils import ensure_serializable, make_hash
from modules.transforms.data_transform import DataTransform


Expand Down
15 changes: 15 additions & 0 deletions modules/data/utils/custom_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch_geometric


class CustomDataset(torch_geometric.data.InMemoryDataset):
def __init__(self, data_list, data_dir, transform=None):
self.data_list = data_list
super().__init__(data_dir, transform)
self.load(self.processed_paths[0])

@property
def processed_file_names(self):
return "data.pt"

def process(self):
self.save(self.data_list, self.processed_paths[0])
16 changes: 8 additions & 8 deletions modules/io/utils/utils.py → modules/data/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ def get_complex_connectivity(complex, max_rank, signed=False):
)
except ValueError:
if connectivity_info == "incidence":
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
else:
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity["shape"] = practical_shape
return connectivity
Expand Down
228 changes: 0 additions & 228 deletions modules/io/utils/split_utils.py

This file was deleted.

10 changes: 7 additions & 3 deletions modules/models/cell/cwn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@ class CWNModel(torch.nn.Module):
"""

def __init__(self, model_config, dataset_config):
in_channels_0 = dataset_config["num_features"]
in_channels_1 = dataset_config["num_features"]
in_channels_2 = dataset_config["num_features"]
in_channels_0 = (
dataset_config["num_features"]
if isinstance(dataset_config["num_features"], int)
else dataset_config["num_features"][0]
)
in_channels_1 = in_channels_0
in_channels_2 = in_channels_0
hidden_channels = model_config["hidden_channels"]
out_channels = dataset_config["num_classes"]
n_layers = model_config["n_layers"]
Expand Down
6 changes: 5 additions & 1 deletion modules/models/hypergraph/unigcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ class UniGCNModel(torch.nn.Module):
"""

def __init__(self, model_config, dataset_config):
in_channels = dataset_config["num_features"]
in_channels = (
dataset_config["num_features"]
if isinstance(dataset_config["num_features"], int)
else dataset_config["num_features"][0]
)
hidden_channels = model_config["hidden_channels"]
out_channels = dataset_config["num_classes"]
n_layers = model_config["n_layers"]
Expand Down
6 changes: 5 additions & 1 deletion modules/models/simplicial/san.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ class SANModel(torch.nn.Module):
"""

def __init__(self, model_config, dataset_config):
in_channels = dataset_config["num_features"]
in_channels = (
dataset_config["num_features"]
if isinstance(dataset_config["num_features"], int)
else dataset_config["num_features"][0]
)
hidden_channels = model_config["hidden_channels"]
out_channels = dataset_config["num_classes"]
n_layers = model_config["n_layers"]
Expand Down
2 changes: 1 addition & 1 deletion modules/transforms/liftings/graph2cell/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
from toponetx.classes import CellComplex

from modules.io.utils.utils import get_complex_connectivity
from modules.data.utils.utils import get_complex_connectivity
from modules.transforms.liftings.lifting import GraphLifting


Expand Down
2 changes: 1 addition & 1 deletion modules/transforms/liftings/graph2simplicial/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
from toponetx.classes import SimplicialComplex

from modules.io.utils.utils import get_complex_connectivity
from modules.data.utils.utils import get_complex_connectivity
from modules.transforms.liftings.lifting import GraphLifting


Expand Down
Loading

0 comments on commit 7063465

Please sign in to comment.