CuGraph Example Fixes (#9577)
This PR makes the output of the single-GPU example more readable. It
also resolves an import issue by moving some of the cuGraph imports
into the functions that use them, which prevents the PyTorch allocator
from being set before it can be replaced with the `rmm_pytorch_allocator`.

Resolves rapidsai/cugraph-gnn#26
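
For reference, here is a minimal sketch of the per-worker setup order that deferring the imports enables. The RMM calls (`rmm.reinitialize`, `rmm.allocators.torch.rmm_torch_allocator`) and PyTorch's `torch.cuda.memory.change_current_allocator` are shown for illustration and are not part of this diff:

```python
# Illustrative sketch only: shows why the `cugraph.gnn` imports are deferred.
# Assumes rmm.reinitialize, rmm_torch_allocator, and
# torch.cuda.memory.change_current_allocator are available.
import rmm
import torch
from rmm.allocators.torch import rmm_torch_allocator


def init_pytorch_worker(rank, world_size, uid):
    # Create an RMM pool on this worker's GPU.
    rmm.reinitialize(devices=[rank], pool_allocator=True)

    # Swap PyTorch's CUDA allocator for the RMM-backed one. This call errors
    # out if the default allocator has already been used, so it must run
    # before any code path (such as a top-level `cugraph.gnn` import) can
    # trigger it.
    torch.cuda.memory.change_current_allocator(rmm_torch_allocator)

    # Only now is it safe to import and initialize cuGraph comms.
    from cugraph.gnn import cugraph_comms_init
    cugraph_comms_init(rank=rank, world_size=world_size, uid=uid,
                       device=rank)
```

The multi-GPU example below follows the same idea by importing `cugraph_comms_init`, `cugraph_comms_shutdown`, and `cugraph_comms_create_unique_id` only at their points of use.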

---------

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
alexbarghi-nv and rusty1s authored Aug 7, 2024
1 parent 8c849a4 commit c809173
Showing 3 changed files with 11 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -61,6 +61,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed an issue where import order in the multi-GPU `cugraph` example could cause an `rmm` error ([#9577](https://github.com/pyg-team/pytorch_geometric/pull/9577))
+- Made the output of the single-GPU `cugraph` example more readable ([#9577](https://github.com/pyg-team/pytorch_geometric/pull/9577))
 - Fixed `load_state_dict` behavior with lazy parameters in `HeteroDictLinear` ([#9493](https://github.com/pyg-team/pytorch_geometric/pull/9493))
 - `Sequential` can now be properly pickled ([#9369](https://github.com/pyg-team/pytorch_geometric/pull/9369))
 - Fixed `pickle.load` for jittable `MessagePassing` modules ([#9368](https://github.com/pyg-team/pytorch_geometric/pull/9368))
10 changes: 4 additions & 6 deletions examples/multi_gpu/papers100m_gcn_cugraph.py
@@ -10,11 +10,6 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn.functional as F
-from cugraph.gnn import (
-    cugraph_comms_create_unique_id,
-    cugraph_comms_init,
-    cugraph_comms_shutdown,
-)
 from ogb.nodeproppred import PygNodePropPredDataset
 from torch.nn.parallel import DistributedDataParallel

@@ -55,6 +50,7 @@ def init_pytorch_worker(rank, world_size, cugraph_id):
     os.environ['MASTER_PORT'] = '12355'
     dist.init_process_group('nccl', rank=rank, world_size=world_size)
 
+    from cugraph.gnn import cugraph_comms_init
     cugraph_comms_init(rank=rank, world_size=world_size, uid=cugraph_id,
                        device=rank)

@@ -212,6 +208,7 @@ def run(rank, data, world_size, cugraph_id, model, epochs, batch_size, fan_out,
     print(f"Total Training Runtime: {total_time - prep_time}s")
     print(f"Total Program Runtime: {total_time}s")
 
+    from cugraph.gnn import cugraph_comms_shutdown
     cugraph_comms_shutdown()
     dist.destroy_process_group()

@@ -257,9 +254,10 @@ def run(rank, data, world_size, cugraph_id, model, epochs, batch_size, fan_out,
         world_size = torch.cuda.device_count()
     else:
         world_size = args.n_devices
-    print(f"Using {world_size} many GPUs...")
+    print(f"Using {world_size} GPUs...")
 
     # Create the uid needed for cuGraph comms
+    from cugraph.gnn import cugraph_comms_create_unique_id
     cugraph_id = cugraph_comms_create_unique_id()
 
     with tempfile.TemporaryDirectory(dir=args.tempdir_root) as tempdir:
7 changes: 5 additions & 2 deletions examples/ogbn_papers_100m_cugraph.py
@@ -147,7 +147,7 @@ def train():
     model.train()
 
     total_loss = total_correct = total_examples = 0
-    for batch in tqdm(train_loader):
+    for i, batch in enumerate(train_loader):
         batch = batch.cuda()
         optimizer.zero_grad()
         out = model(batch.x, batch.edge_index)[:batch.batch_size]
@@ -160,14 +160,17 @@ def train():
         total_correct += int(out.argmax(dim=-1).eq(y).sum())
         total_examples += y.size(0)
 
+        if i % 10 == 0:
+            print(f"Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}")
+
     return total_loss / total_examples, total_correct / total_examples
 
 @torch.no_grad()
 def test(loader):
     model.eval()
 
     total_correct = total_examples = 0
-    for batch in tqdm(loader):
+    for batch in loader:
         batch = batch.cuda()
         out = model(batch.x, batch.edge_index)[:batch.batch_size]
         y = batch.y[:batch.batch_size].view(-1).to(torch.long)
