diff --git a/CHANGELOG.md b/CHANGELOG.md
index 271acd853b5d..2696612f896e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Consolidated the `ogbn-products` and `ogbn-papers100M` basic examples into a single `ogbn_train.py` example with additional improvements ([#9467](https://github.com/pyg-team/pytorch_geometric/pull/9467))
+
 ### Changed
 
 ### Deprecated
diff --git a/examples/README.md b/examples/README.md
index 64fc2dffdcdd..bed728ee0ea7 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -9,9 +9,8 @@ For a simple link prediction example, see [`link_pred.py`](./link_pred.py).
 
 For examples on [Open Graph Benchmark](https://ogb.stanford.edu/) datasets, see the `ogbn_*.py` examples:
 
-- [`ogbn_products_sage.py`](./ogbn_products_sage.py) and [`ogbn_products_gat.py`](./ogbn_products_gat.py) show how to train [`GraphSAGE`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GraphSAGE.html) and [`GAT`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GAT.html) models on the `ogbn-products` dataset.
+- [`ogbn_train.py`](./ogbn_train.py) shows how to train a GNN on either the large-scale `ogbn-papers100M` dataset (~1.6B edges) or the medium-scale `ogbn-products` dataset (~62M edges).
 - [`ogbn_proteins_deepgcn.py`](./ogbn_proteins_deepgcn.py) is an example to showcase how to train deep GNNs on the `ogbn-proteins` dataset.
-- [`ogbn_papers_100m.py`](./ogbn_papers_100m.py) is an example for training a GNN on the large-scale `ogbn-papers100m` dataset, containing approximately ~1.6B edges.
 - [`ogbn_papers_100m_cugraph.py`](./ogbn_papers_100m_cugraph.py) shows how to accelerate the `ogbn-papers100m` workflow using [CuGraph](https://github.com/rapidsai/cugraph).
 
 For examples on using `torch.compile`, see the examples under [`examples/compile`](./compile).
diff --git a/examples/ogbn_papers_100m.py b/examples/ogbn_papers_100m.py
deleted file mode 100644
index d7ad3d01f990..000000000000
--- a/examples/ogbn_papers_100m.py
+++ /dev/null
@@ -1,126 +0,0 @@
-# Reaches around 0.7870 ± 0.0036 test accuracy.
-
-import argparse
-import os.path as osp
-
-import torch
-import torch.nn.functional as F
-from ogb.nodeproppred import PygNodePropPredDataset
-from tqdm import tqdm
-
-from torch_geometric.loader import NeighborLoader
-from torch_geometric.nn import SAGEConv
-from torch_geometric.utils import to_undirected
-
-parser = argparse.ArgumentParser()
-parser.add_argument('--device', type=str, default='cuda')
-parser.add_argument('--epochs', type=int, default=10)
-parser.add_argument('--num_layers', type=int, default=3)
-parser.add_argument('--batch_size', type=int, default=1024)
-parser.add_argument('--num_neighbors', type=int, default=10)
-parser.add_argument('--channels', type=int, default=256)
-parser.add_argument('--lr', type=float, default=0.003)
-parser.add_argument('--dropout', type=float, default=0.5)
-parser.add_argument('--num_workers', type=int, default=12)
-args = parser.parse_args()
-
-root = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'papers100')
-dataset = PygNodePropPredDataset('ogbn-papers100M', root)
-split_idx = dataset.get_idx_split()
-
-data = dataset[0]
-data.edge_index = to_undirected(data.edge_index, reduce="mean")
-
-train_loader = NeighborLoader(
-    data,
-    input_nodes=split_idx['train'],
-    num_neighbors=[args.num_neighbors] * args.num_layers,
-    batch_size=args.batch_size,
-    shuffle=True,
-    num_workers=args.num_workers,
-    persistent_workers=args.num_workers > 0,
-)
-val_loader = NeighborLoader(
-    data,
-    input_nodes=split_idx['valid'],
-    num_neighbors=[args.num_neighbors] * args.num_layers,
-    batch_size=args.batch_size,
-    num_workers=args.num_workers,
-    persistent_workers=args.num_workers > 0,
-)
-test_loader = NeighborLoader(
-    data,
-    input_nodes=split_idx['test'],
-    num_neighbors=[args.num_neighbors] * args.num_layers,
-    batch_size=args.batch_size,
-    num_workers=args.num_workers,
-    persistent_workers=args.num_workers > 0,
-)
-
-
-class SAGE(torch.nn.Module):
-    def __init__(self, in_channels, out_channels):
-        super().__init__()
-
-        self.convs = torch.nn.ModuleList()
-        self.convs.append(SAGEConv(in_channels, args.channels))
-        for _ in range(args.num_layers - 2):
-            self.convs.append(SAGEConv(args.channels, args.channels))
-        self.convs.append(SAGEConv(args.channels, out_channels))
-
-    def forward(self, x, edge_index):
-        for i, conv in enumerate(self.convs):
-            x = conv(x, edge_index)
-            if i != args.num_layers - 1:
-                x = x.relu()
-                x = F.dropout(x, p=args.dropout, training=self.training)
-        return x
-
-
-model = SAGE(dataset.num_features, dataset.num_classes).to(args.device)
-optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
-
-
-def train():
-    model.train()
-
-    total_loss = total_correct = total_examples = 0
-    for batch in tqdm(train_loader):
-        batch = batch.to(args.device)
-        optimizer.zero_grad()
-        out = model(batch.x, batch.edge_index)[:batch.batch_size]
-        y = batch.y[:batch.batch_size].view(-1).to(torch.long)
-        loss = F.cross_entropy(out, y)
-        loss.backward()
-        optimizer.step()
-
-        total_loss += float(loss) * y.size(0)
-        total_correct += int(out.argmax(dim=-1).eq(y).sum())
-        total_examples += y.size(0)
-
-    return total_loss / total_examples, total_correct / total_examples
-
-
-@torch.no_grad()
-def test(loader):
-    model.eval()
-
-    total_correct = total_examples = 0
-    for batch in tqdm(loader):
-        batch = batch.to(args.device)
-        out = model(batch.x, batch.edge_index)[:batch.batch_size]
-        y = batch.y[:batch.batch_size].view(-1).to(torch.long)
-
-        total_correct += int(out.argmax(dim=-1).eq(y).sum())
-        total_examples += y.size(0)
-
-    return total_correct / total_examples
-
-
-for epoch in range(1, args.epochs + 1):
-    loss, train_acc = train()
-    print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}')
-    val_acc = test(val_loader)
-    print(f'Epoch {epoch:02d}, Val Acc: {val_acc:.4f}')
-    test_acc = test(test_loader)
-    print(f'Epoch {epoch:02d}, Test Acc: {test_acc:.4f}')
diff --git a/examples/ogbn_products_gat.py b/examples/ogbn_products_gat.py
deleted file mode 100644
index 06d01e94b1f8..000000000000
--- a/examples/ogbn_products_gat.py
+++ /dev/null
@@ -1,185 +0,0 @@
-# Reaches around 0.7945 ± 0.0059 test accuracy.
-
-import os.path as osp
-
-import torch
-import torch.nn.functional as F
-from ogb.nodeproppred import Evaluator, PygNodePropPredDataset
-from torch.nn import Linear as Lin
-from tqdm import tqdm
-
-from torch_geometric.loader import NeighborLoader
-from torch_geometric.nn import GATConv
-
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-root = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'products')
-dataset = PygNodePropPredDataset('ogbn-products', root)
-split_idx = dataset.get_idx_split()
-evaluator = Evaluator(name='ogbn-products')
-data = dataset[0].to(device, 'x', 'y')
-
-train_loader = NeighborLoader(
-    data,
-    input_nodes=split_idx['train'],
-    num_neighbors=[10, 10, 5],
-    batch_size=512,
-    shuffle=True,
-    num_workers=12,
-    persistent_workers=True,
-)
-subgraph_loader = NeighborLoader(
-    data,
-    input_nodes=None,
-    num_neighbors=[-1],
-    batch_size=2048,
-    num_workers=12,
-    persistent_workers=True,
-)
-
-
-class GAT(torch.nn.Module):
-    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
-                 heads):
-        super().__init__()
-
-        self.num_layers = num_layers
-
-        self.convs = torch.nn.ModuleList()
-        self.convs.append(GATConv(dataset.num_features, hidden_channels,
-                                  heads))
-        for _ in range(num_layers - 2):
-            self.convs.append(
-                GATConv(heads * hidden_channels, hidden_channels, heads))
-        self.convs.append(
-            GATConv(heads * hidden_channels, out_channels, heads,
-                    concat=False))
-
-        self.skips = torch.nn.ModuleList()
-        self.skips.append(Lin(dataset.num_features, hidden_channels * heads))
-        for _ in range(num_layers - 2):
-            self.skips.append(
-                Lin(hidden_channels * heads, hidden_channels * heads))
-        self.skips.append(Lin(hidden_channels * heads, out_channels))
-
-    def reset_parameters(self):
-        for conv in self.convs:
-            conv.reset_parameters()
-        for skip in self.skips:
-            skip.reset_parameters()
-
-    def forward(self, x, edge_index):
-        for i, (conv, skip) in enumerate(zip(self.convs, self.skips)):
-            x = conv(x, edge_index) + skip(x)
-            if i != self.num_layers - 1:
-                x = F.elu(x)
-                x = F.dropout(x, p=0.5, training=self.training)
-        return x
-
-    def inference(self, x_all):
-        pbar = tqdm(total=x_all.size(0) * self.num_layers)
-        pbar.set_description('Evaluating')
-
-        # Compute representations of nodes layer by layer, using *all*
-        # available edges. This leads to faster computation in contrast to
-        # immediately computing the final representations of each batch.
-        for i in range(self.num_layers):
-            xs = []
-            for batch in subgraph_loader:
-                x = x_all[batch.n_id].to(device)
-                edge_index = batch.edge_index.to(device)
-                x = self.convs[i](x, edge_index) + self.skips[i](x)
-                x = x[:batch.batch_size]
-                if i != self.num_layers - 1:
-                    x = F.elu(x)
-                xs.append(x.cpu())
-
-                pbar.update(batch.batch_size)
-
-            x_all = torch.cat(xs, dim=0)
-
-        pbar.close()
-
-        return x_all
-
-
-model = GAT(dataset.num_features, 128, dataset.num_classes, num_layers=3,
-            heads=4).to(device)
-
-
-def train(epoch):
-    model.train()
-
-    pbar = tqdm(total=split_idx['train'].size(0))
-    pbar.set_description(f'Epoch {epoch:02d}')
-
-    total_loss = total_correct = 0
-    for batch in train_loader:
-        optimizer.zero_grad()
-        out = model(batch.x, batch.edge_index.to(device))[:batch.batch_size]
-        y = batch.y[:batch.batch_size].squeeze()
-        loss = F.cross_entropy(out, y)
-        loss.backward()
-        optimizer.step()
-
-        total_loss += float(loss)
-        total_correct += int(out.argmax(dim=-1).eq(y).sum())
-        pbar.update(batch.batch_size)
-
-    pbar.close()
-
-    loss = total_loss / len(train_loader)
-    approx_acc = total_correct / split_idx['train'].size(0)
-
-    return loss, approx_acc
-
-
-@torch.no_grad()
-def test():
-    model.eval()
-
-    out = model.inference(data.x)
-
-    y_true = data.y.cpu()
-    y_pred = out.argmax(dim=-1, keepdim=True)
-
-    train_acc = evaluator.eval({
-        'y_true': y_true[split_idx['train']],
-        'y_pred': y_pred[split_idx['train']],
-    })['acc']
-    val_acc = evaluator.eval({
-        'y_true': y_true[split_idx['valid']],
-        'y_pred': y_pred[split_idx['valid']],
-    })['acc']
-    test_acc = evaluator.eval({
-        'y_true': y_true[split_idx['test']],
-        'y_pred': y_pred[split_idx['test']],
-    })['acc']
-
-    return train_acc, val_acc, test_acc
-
-
-test_accs = []
-for run in range(1, 11):
-    print(f'\nRun {run:02d}:\n')
-
-    model.reset_parameters()
-    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
-
-    best_val_acc = final_test_acc = 0.0
-    for epoch in range(1, 101):
-        loss, acc = train(epoch)
-        print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Approx. Train: {acc:.4f}')
-
-        if epoch > 50 and epoch % 10 == 0:
-            train_acc, val_acc, test_acc = test()
-            print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
-                  f'Test: {test_acc:.4f}')
-
-            if val_acc > best_val_acc:
-                best_val_acc = val_acc
-                final_test_acc = test_acc
-    test_accs.append(final_test_acc)
-
-test_acc = torch.tensor(test_accs)
-print('============================')
-print(f'Final Test: {test_acc.mean():.4f} ± {test_acc.std():.4f}')
diff --git a/examples/ogbn_products_sage.py b/examples/ogbn_products_sage.py
deleted file mode 100644
index 4f1a442ff30a..000000000000
--- a/examples/ogbn_products_sage.py
+++ /dev/null
@@ -1,170 +0,0 @@
-# Reaches around 0.7870 ± 0.0036 test accuracy.
-
-import os.path as osp
-
-import torch
-import torch.nn.functional as F
-from ogb.nodeproppred import Evaluator, PygNodePropPredDataset
-from tqdm import tqdm
-
-from torch_geometric.loader import NeighborLoader
-from torch_geometric.nn import SAGEConv
-
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-root = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'products')
-dataset = PygNodePropPredDataset('ogbn-products', root)
-split_idx = dataset.get_idx_split()
-evaluator = Evaluator(name='ogbn-products')
-data = dataset[0].to(device, 'x', 'y')
-
-train_loader = NeighborLoader(
-    data,
-    input_nodes=split_idx['train'],
-    num_neighbors=[15, 10, 5],
-    batch_size=1024,
-    shuffle=True,
-    num_workers=12,
-    persistent_workers=True,
-)
-subgraph_loader = NeighborLoader(
-    data,
-    input_nodes=None,
-    num_neighbors=[-1],
-    batch_size=4096,
-    num_workers=12,
-    persistent_workers=True,
-)
-
-
-class SAGE(torch.nn.Module):
-    def __init__(self, in_channels, hidden_channels, out_channels,
-                 num_layers):
-        super().__init__()
-
-        self.num_layers = num_layers
-
-        self.convs = torch.nn.ModuleList()
-        self.convs.append(SAGEConv(in_channels, hidden_channels))
-        for _ in range(num_layers - 2):
-            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
-        self.convs.append(SAGEConv(hidden_channels, out_channels))
-
-    def reset_parameters(self):
-        for conv in self.convs:
-            conv.reset_parameters()
-
-    def forward(self, x, edge_index):
-        for i, conv in enumerate(self.convs):
-            x = conv(x, edge_index)
-            if i != self.num_layers - 1:
-                x = x.relu()
-                x = F.dropout(x, p=0.5, training=self.training)
-        return x
-
-    def inference(self, x_all):
-        pbar = tqdm(total=x_all.size(0) * self.num_layers)
-        pbar.set_description('Evaluating')
-
-        # Compute representations of nodes layer by layer, using *all*
-        # available edges. This leads to faster computation in contrast to
-        # immediately computing the final representations of each batch.
-        for i in range(self.num_layers):
-            xs = []
-            for batch in subgraph_loader:
-                x = x_all[batch.n_id].to(device)
-                edge_index = batch.edge_index.to(device)
-                x = self.convs[i](x, edge_index)
-                x = x[:batch.batch_size]
-                if i != self.num_layers - 1:
-                    x = x.relu()
-                xs.append(x.cpu())
-
-                pbar.update(batch.batch_size)
-
-            x_all = torch.cat(xs, dim=0)
-
-        pbar.close()
-
-        return x_all
-
-
-model = SAGE(dataset.num_features, 256, dataset.num_classes, num_layers=3)
-model = model.to(device)
-
-
-def train(epoch):
-    model.train()
-
-    pbar = tqdm(total=split_idx['train'].size(0))
-    pbar.set_description(f'Epoch {epoch:02d}')
-
-    total_loss = total_correct = 0
-    for batch in train_loader:
-        optimizer.zero_grad()
-        out = model(batch.x, batch.edge_index.to(device))[:batch.batch_size]
-        y = batch.y[:batch.batch_size].squeeze()
-        loss = F.cross_entropy(out, y)
-        loss.backward()
-        optimizer.step()
-
-        total_loss += float(loss)
-        total_correct += int(out.argmax(dim=-1).eq(y).sum())
-        pbar.update(batch.batch_size)
-
-    pbar.close()
-
-    loss = total_loss / len(train_loader)
-    approx_acc = total_correct / split_idx['train'].size(0)
-
-    return loss, approx_acc
-
-
-@torch.no_grad()
-def test():
-    model.eval()
-
-    out = model.inference(data.x)
-
-    y_true = data.y.cpu()
-    y_pred = out.argmax(dim=-1, keepdim=True)
-
-    train_acc = evaluator.eval({
-        'y_true': y_true[split_idx['train']],
-        'y_pred': y_pred[split_idx['train']],
-    })['acc']
-    val_acc = evaluator.eval({
-        'y_true': y_true[split_idx['valid']],
-        'y_pred': y_pred[split_idx['valid']],
-    })['acc']
-    test_acc = evaluator.eval({
-        'y_true': y_true[split_idx['test']],
-        'y_pred': y_pred[split_idx['test']],
-    })['acc']
-
-    return train_acc, val_acc, test_acc
-
-
-test_accs = []
-for run in range(1, 11):
-    print(f'\nRun {run:02d}:\n')
-
-    model.reset_parameters()
-    optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
-
-    best_val_acc = final_test_acc = 0.0
-    for epoch in range(1, 21):
-        loss, acc = train(epoch)
-        print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Approx. Train: {acc:.4f}')
-
-        if epoch > 5:
-            train_acc, val_acc, test_acc = test()
-            print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
-                  f'Test: {test_acc:.4f}')
-
-            if val_acc > best_val_acc:
-                best_val_acc = val_acc
-                final_test_acc = test_acc
-    test_accs.append(final_test_acc)
-
-test_acc = torch.tensor(test_accs)
-print('============================')
-print(f'Final Test: {test_acc.mean():.4f} ± {test_acc.std():.4f}')
diff --git a/examples/ogbn_train.py b/examples/ogbn_train.py
new file mode 100644
index 000000000000..22308a171708
--- /dev/null
+++ b/examples/ogbn_train.py
@@ -0,0 +1,267 @@
+# Consolidated example for training a GraphSAGE or GAT model on the
+# `ogbn-products` or `ogbn-papers100M` dataset via neighbor sampling.
+import argparse
+import os.path as osp
+import time
+from typing import Optional, Tuple
+
+import psutil
+import torch
+import torch.nn.functional as F
+from ogb.nodeproppred import PygNodePropPredDataset
+from tqdm import tqdm
+
+from torch_geometric.loader import NeighborLoader
+from torch_geometric.utils import to_undirected
+
+parser = argparse.ArgumentParser(
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument(
+    '--dataset',
+    type=str,
+    default='ogbn-papers100M',
+    choices=['ogbn-papers100M', 'ogbn-products'],
+    help='Dataset name.',
+)
+parser.add_argument(
+    '--dataset_dir',
+    type=str,
+    default='./data',
+    help='Root directory of dataset.',
+)
+parser.add_argument(
+    '--dataset_subdir',
+    type=str,
+    default='ogb-papers100M',
+    help='Subdirectory of the dataset.',
+)
+parser.add_argument(
+    '--use_gat',
+    action='store_true',
+    help='Whether or not to use the GAT model.',
+)
+parser.add_argument(
+    '--verbose',
+    action='store_true',
+    help='Whether or not to generate a statistical report.',
+)
+parser.add_argument('--device', type=str, default='cuda')
+parser.add_argument('-e', '--epochs', type=int, default=10,
+                    help='Number of training epochs.')
+parser.add_argument('--num_layers', type=int, default=3,
+                    help='Number of layers.')
+parser.add_argument('--num_heads', type=int, default=2,
+                    help='Number of heads for the GAT model.')
+parser.add_argument('-b', '--batch_size', type=int, default=1024,
+                    help='Batch size.')
+parser.add_argument('--num_workers', type=int, default=12,
+                    help='Number of workers.')
+parser.add_argument('--fan_out', type=int, default=10,
+                    help='Number of neighbors sampled in each layer.')
+parser.add_argument('--hidden_channels', type=int, default=256,
+                    help='Number of hidden channels.')
+parser.add_argument('--lr', type=float, default=0.003)
+parser.add_argument('--wd', type=float, default=0.0,
+                    help='Weight decay for the optimizer.')
+parser.add_argument('--dropout', type=float, default=0.5)
+parser.add_argument(
+    '--use_directed_graph',
+    action='store_true',
+    help='Whether or not to keep the graph directed.',
+)
+args = parser.parse_args()
+
+if 'papers' in args.dataset and (psutil.virtual_memory().total /
+                                 (1024**3)) < 390:
+    print('Warning: may not have enough RAM to process this dataset.')
+    print('Consider upgrading RAM if an error occurs.')
+    print('Estimated RAM needed: ~390GB.')
+
+verbose = args.verbose
+if verbose:
+    wall_clock_start = time.perf_counter()
+    if args.use_gat:
+        print(f'Training {args.dataset} with the GAT model.')
+    else:
+        print(f'Training {args.dataset} with the GraphSAGE model.')
+
+if not torch.cuda.is_available():
+    args.device = 'cpu'
+device = torch.device(args.device)
+
+num_epochs = args.epochs
+num_layers = args.num_layers
+num_workers = args.num_workers
+num_hidden_channels = args.hidden_channels
+batch_size = args.batch_size
+
+root = osp.join(args.dataset_dir, args.dataset_subdir)
+print(f'The root is: {root}')
+dataset = PygNodePropPredDataset(name=args.dataset, root=root)
+split_idx = dataset.get_idx_split()
+
+data = dataset[0]
+if not args.use_directed_graph:
+    data.edge_index = to_undirected(data.edge_index, reduce='mean')
+
+# Move node features and labels to the target device up front:
+data.to(device, 'x', 'y')
+
+train_loader = NeighborLoader(
+    data,
+    input_nodes=split_idx['train'],
+    num_neighbors=[args.fan_out] * num_layers,
+    batch_size=batch_size,
+    shuffle=True,
+    num_workers=num_workers,
+    persistent_workers=True,
+)
+val_loader = NeighborLoader(
+    data,
+    input_nodes=split_idx['valid'],
+    num_neighbors=[args.fan_out] * num_layers,
+    batch_size=batch_size,
+    shuffle=True,
+    num_workers=num_workers,
+    persistent_workers=True,
+)
+test_loader = NeighborLoader(
+    data,
+    input_nodes=split_idx['test'],
+    num_neighbors=[args.fan_out] * num_layers,
+    batch_size=batch_size,
+    shuffle=True,
+    num_workers=num_workers,
+    persistent_workers=True,
+)
+
+
+def train(epoch: int) -> Tuple[float, float]:
+    model.train()
+
+    pbar = tqdm(total=split_idx['train'].size(0))
+    pbar.set_description(f'Epoch {epoch:02d}')
+
+    total_loss = total_correct = 0
+    for batch in train_loader:
+        optimizer.zero_grad()
+        out = model(batch.x, batch.edge_index.to(device))[:batch.batch_size]
+        y = batch.y[:batch.batch_size].squeeze().to(torch.long)
+        loss = F.cross_entropy(out, y)
+        loss.backward()
+        optimizer.step()
+
+        total_loss += float(loss)
+        total_correct += int(out.argmax(dim=-1).eq(y).sum())
+        pbar.update(batch.batch_size)
+
+    pbar.close()
+    loss = total_loss / len(train_loader)
+    approx_acc = total_correct / split_idx['train'].size(0)
+    return loss, approx_acc
+
+
+@torch.no_grad()
+def test(loader: NeighborLoader, val_steps: Optional[int] = None) -> float:
+    model.eval()
+
+    total_correct = total_examples = 0
+    for i, batch in enumerate(loader):
+        if val_steps is not None and i >= val_steps:
+            break
+        batch = batch.to(device)
+        batch_size = batch.num_sampled_nodes[0]
+        out = model(batch.x, batch.edge_index)[:batch_size]
+        pred = out.argmax(dim=-1)
+        y = batch.y[:batch_size].view(-1).to(torch.long)
+
+        total_correct += int((pred == y).sum())
+        total_examples += y.size(0)
+
+    return total_correct / total_examples
+
+
+if args.use_gat:
+    from torch_geometric.nn.models import GAT
+    model = GAT(
+        in_channels=dataset.num_features,
+        hidden_channels=num_hidden_channels,
+        num_layers=num_layers,
+        out_channels=dataset.num_classes,
+        dropout=args.dropout,
+        heads=args.num_heads,
+    )
+else:
+    from torch_geometric.nn.models import GraphSAGE
+    model = GraphSAGE(
+        in_channels=dataset.num_features,
+        hidden_channels=num_hidden_channels,
+        num_layers=num_layers,
+        out_channels=dataset.num_classes,
+        dropout=args.dropout,
+    )
+
+model = model.to(device)
+optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
+                             weight_decay=args.wd)
+
+if verbose:
+    prep_time = round(time.perf_counter() - wall_clock_start, 2)
+    print(f'Total time before training begins (prep_time) = '
+          f'{prep_time} seconds')
+    print('Training...')
+
+test_accs = []
+val_accs = []
+times = []
+train_times = []
+inference_times = []
+best_val = best_test = 0.
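+
+# The loop below trains for `--epochs` epochs, timing the training and
+# evaluation phases of each epoch separately and tracking the best
+# validation and test accuracies seen so far.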
+model.reset_parameters()
+
+for epoch in range(1, num_epochs + 1):
+    train_start = time.time()
+    loss, _ = train(epoch)
+    train_end = time.time()
+    train_times.append(train_end - train_start)
+
+    inference_start = time.time()
+    val_acc = test(val_loader)
+    test_acc = test(test_loader)
+
+    inference_times.append(time.time() - inference_start)
+    test_accs.append(test_acc)
+    val_accs.append(val_acc)
+    print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, '
+          f'Train Time: {train_end - train_start:.4f}s')
+    print(f'Val: {val_acc * 100.0:.2f}%, Test: {test_acc * 100.0:.2f}%')
+
+    if val_acc > best_val:
+        best_val = val_acc
+    if test_acc > best_test:
+        best_test = test_acc
+    times.append(time.time() - train_start)
+
+if verbose:
+    test_acc = torch.tensor(test_accs)
+    val_acc = torch.tensor(val_accs)
+    print('============================')
+    print(f'Average Epoch Time on training: '
+          f'{torch.tensor(train_times).mean():.4f}s')
+    print(f'Average Epoch Time on inference: '
+          f'{torch.tensor(inference_times).mean():.4f}s')
+    print(f'Average Epoch Time: {torch.tensor(times).mean():.4f}s')
+    print(f'Median time per epoch: {torch.tensor(times).median():.4f}s')
+    print(f'Final Test: {test_acc.mean():.4f} ± {test_acc.std():.4f}')
+    print(f'Final Validation: {val_acc.mean():.4f} ± {val_acc.std():.4f}')
+    print(f'Best validation accuracy: {best_val:.4f}')
+    print(f'Best testing accuracy: {best_test:.4f}')
+
+if verbose:
+    print('Testing...')
+test_final_acc = test(test_loader)
+print(f'Test Accuracy: {100.0 * test_final_acc:.2f}%')
+if verbose:
+    total_time = round(time.perf_counter() - wall_clock_start, 2)
+    print(f'Total Program Runtime (total_time) = {total_time} seconds')