Skip to content

Commit

Permalink
Fix custom NeighborLoader test (#7680)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Jul 3, 2023
1 parent e45059a commit 48df293
Showing 1 changed file with 22 additions and 29 deletions.
51 changes: 22 additions & 29 deletions test/loader/test_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import pytest
import torch

import torch_geometric.typing
from torch_geometric.data import Data, HeteroData
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import GraphConv, to_hetero
Expand All @@ -25,7 +24,6 @@
from torch_geometric.utils import (
is_undirected,
sort_edge_index,
to_torch_csc_tensor,
to_torch_csr_tensor,
to_undirected,
)
Expand Down Expand Up @@ -390,15 +388,16 @@ def test_custom_neighbor_loader():
feature_store.put_tensor(x, group_name='author', attr_name='x', index=None)

# COO:
edge_index = get_random_edge_index(100, 100, 500)
edge_index = get_random_edge_index(100, 100, 500, coalesce=True)
edge_index = edge_index[:, torch.randperm(edge_index.size(1))]
data['paper', 'to', 'paper'].edge_index = edge_index
coo = (edge_index[0], edge_index[1])
graph_store.put_edge_index(edge_index=coo,
edge_type=('paper', 'to', 'paper'),
layout='coo', size=(100, 100))

# CSR:
edge_index = get_random_edge_index(100, 200, 1000)
edge_index = get_random_edge_index(100, 200, 1000, coalesce=True)
data['paper', 'to', 'author'].edge_index = edge_index
adj = to_torch_csr_tensor(edge_index, size=(100, 200))
csr = (adj.crow_indices(), adj.col_indices())
Expand All @@ -407,17 +406,16 @@ def test_custom_neighbor_loader():
layout='csr', size=(100, 200))

# CSC:
if torch_geometric.typing.WITH_PT112:
edge_index = get_random_edge_index(200, 100, 1000)
data['author', 'to', 'paper'].edge_index = edge_index
adj = to_torch_csc_tensor(edge_index, size=(200, 100))
csc = (adj.row_indices(), adj.ccol_indices())
graph_store.put_edge_index(edge_index=csc,
edge_type=('author', 'to', 'paper'),
layout='csc', size=(200, 100))
edge_index = get_random_edge_index(200, 100, 1000, coalesce=True)
data['author', 'to', 'paper'].edge_index = edge_index
adj = to_torch_csr_tensor(edge_index.flip([0]), size=(100, 200))
csc = (adj.col_indices(), adj.crow_indices())
graph_store.put_edge_index(edge_index=csc,
edge_type=('author', 'to', 'paper'),
layout='csc', size=(200, 100))

# COO (sorted):
edge_index = get_random_edge_index(200, 200, 100)
edge_index = get_random_edge_index(200, 200, 100, coalesce=True)
edge_index = edge_index[:, edge_index[1].argsort()]
data['author', 'to', 'author'].edge_index = edge_index
coo = (edge_index[0], edge_index[1])
Expand All @@ -438,25 +436,20 @@ def test_custom_neighbor_loader():
assert len(loader1) == len(loader2)

for batch1, batch2 in zip(loader1, loader2):
# loader2 explicitly adds `num_nodes` to the batch
# `loader2` explicitly adds `num_nodes` to the batch:
assert len(batch1) + 1 == len(batch2)
assert batch1['paper'].batch_size == batch2['paper'].batch_size

# Mapped indices of neighbors may be differently sorted:
assert torch.allclose(batch1['paper'].x.sort()[0],
batch2['paper'].x.sort()[0])
assert torch.allclose(batch1['author'].x.sort()[0],
batch2['author'].x.sort()[0])

assert (batch1['paper', 'to', 'paper'].edge_index.size() == batch1[
'paper', 'to', 'paper'].edge_index.size())
assert (batch1['paper', 'to', 'author'].edge_index.size() == batch1[
'paper', 'to', 'author'].edge_index.size())
if torch_geometric.typing.WITH_PT112:
assert (batch1['author', 'to', 'paper'].edge_index.size() ==
batch1['author', 'to', 'paper'].edge_index.size())
assert (batch1['author', 'to', 'author'].edge_index.size() == batch1[
'author', 'to', 'author'].edge_index.size())
# Mapped indices of neighbors may be differently sorted ...
for node_type in data.node_types:
assert torch.allclose(
batch1[node_type].x.sort()[0],
batch2[node_type].x.sort()[0],
)

# ... but should sample the exact same number of edges:
for edge_type in data.edge_types:
assert batch1[edge_type].num_edges == batch2[edge_type].num_edges


@onlyOnline
Expand Down

0 comments on commit 48df293

Please sign in to comment.