Added the residual option for GATConv and GATv2Conv. (#9515)
As discussed in #7555 (reply in thread), the GAT paper mentions using
skip-connections to reach the reported metrics on the PPI dataset. This PR
adds a `residual` option to enable these skip-connections.

Specifically, the residual is added as follows:
1. Define `actual_out_channels` = `heads * out_channels` if `concat` else
   `out_channels`.
2. If `in_channels` == `actual_out_channels`, the residual is simply the
   input `x` before the convolution. Otherwise, the residual is a linear
   projection of the input `x` onto `actual_out_channels` dimensions.

In the bipartite case, where the input is `(x_src, x_dst)`, the above
applies to `x_dst`.
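
For concreteness, here is a minimal sketch of the channel bookkeeping
described above (not the code added by this PR; `residual_projection` is a
hypothetical helper used only for illustration):

```python
import torch

def residual_projection(in_channels: int, out_channels: int,
                        heads: int, concat: bool) -> torch.nn.Module:
    # Output width of the convolution after head concatenation (or averaging):
    actual_out_channels = heads * out_channels if concat else out_channels
    if in_channels == actual_out_channels:
        # Dimensions already match: the residual is the input itself.
        return torch.nn.Identity()
    # Otherwise, project the input onto the output width.
    return torch.nn.Linear(in_channels, actual_out_channels, bias=False)

# Example: heads=4, out_channels=256, concat=True -> the residual maps 50 -> 1024.
proj = residual_projection(in_channels=50, out_channels=256, heads=4, concat=True)
res = proj(torch.randn(10, 50))  # shape [10, 1024], added to the convolution output
```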

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
3 people authored Jul 29, 2024
1 parent 4e878d1 commit 2ab9971
Showing 6 changed files with 84 additions and 22 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added a `residual` option in `GATConv` and `GATv2Conv` ([#9515](https://github.com/pyg-team/pytorch_geometric/pull/9515))
- Added the `PatchTransformerAggregation` layer ([#9487](https://github.com/pyg-team/pytorch_geometric/pull/9487))
- Added the `nn.nlp.LLM` model ([#9462](https://github.com/pyg-team/pytorch_geometric/pull/9462))
- Added an example of training GNNs for a graph-level regression task ([#9070](https://github.com/pyg-team/pytorch_geometric/pull/9070))
16 changes: 7 additions & 9 deletions examples/ppi.py
@@ -21,18 +21,16 @@
class Net(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = GATConv(train_dataset.num_features, 256, heads=4)
self.lin1 = torch.nn.Linear(train_dataset.num_features, 4 * 256)
self.conv2 = GATConv(4 * 256, 256, heads=4)
self.lin2 = torch.nn.Linear(4 * 256, 4 * 256)
self.conv1 = GATConv(train_dataset.num_features, 256, heads=4,
residual=True)
self.conv2 = GATConv(4 * 256, 256, heads=4, residual=True)
self.conv3 = GATConv(4 * 256, train_dataset.num_classes, heads=6,
concat=False)
self.lin3 = torch.nn.Linear(4 * 256, train_dataset.num_classes)
concat=False, residual=True)

def forward(self, x, edge_index):
x = F.elu(self.conv1(x, edge_index) + self.lin1(x))
x = F.elu(self.conv2(x, edge_index) + self.lin2(x))
x = self.conv3(x, edge_index) + self.lin3(x)
x = F.elu(self.conv1(x, edge_index))
x = F.elu(self.conv2(x, edge_index))
x = self.conv3(x, edge_index)
return x


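As a standalone illustration (a sketch, not part of the diff), the updated
model reads roughly as follows once the dataset globals are passed in
explicitly:

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class Net(torch.nn.Module):
    def __init__(self, num_features: int, num_classes: int):
        super().__init__()
        # The manual torch.nn.Linear skip-connections from the old example
        # are replaced by the new residual=True option:
        self.conv1 = GATConv(num_features, 256, heads=4, residual=True)
        self.conv2 = GATConv(4 * 256, 256, heads=4, residual=True)
        self.conv3 = GATConv(4 * 256, num_classes, heads=6, concat=False,
                             residual=True)

    def forward(self, x, edge_index):
        x = F.elu(self.conv1(x, edge_index))
        x = F.elu(self.conv2(x, edge_index))
        return self.conv3(x, edge_index)
```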
7 changes: 4 additions & 3 deletions test/nn/conv/test_gat_conv.py
@@ -11,13 +11,14 @@
from torch_geometric.utils import to_torch_csc_tensor


def test_gat_conv():
@pytest.mark.parametrize('residual', [False, True])
def test_gat_conv(residual):
x1 = torch.randn(4, 8)
x2 = torch.randn(2, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))

conv = GATConv(8, 32, heads=2)
conv = GATConv(8, 32, heads=2, residual=residual)
assert str(conv) == 'GATConv(8, 32, heads=2)'
out = conv(x1, edge_index)
assert out.size() == (4, 64)
@@ -113,7 +114,7 @@ def forward(
# Test bipartite message passing:
adj1 = to_torch_csc_tensor(edge_index, size=(4, 2))

conv = GATConv((8, 16), 32, heads=2)
conv = GATConv((8, 16), 32, heads=2, residual=residual)
assert str(conv) == 'GATConv((8, 16), 32, heads=2)'

out1 = conv((x1, x2), edge_index)
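For reference, a minimal usage sketch mirroring the test above (tensor shapes
follow the test; this block is not part of the diff):

```python
import torch
from torch_geometric.nn import GATConv

x1 = torch.randn(4, 8)    # 4 source nodes with 8 features
x2 = torch.randn(2, 16)   # 2 target nodes with 16 features
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

# Unipartite case: the residual projects 8 -> 2 * 32 = 64 channels.
conv = GATConv(8, 32, heads=2, residual=True)
out = conv(x1, edge_index)         # shape [4, 64]

# Bipartite case: the residual is computed from the target features x2.
conv = GATConv((8, 16), 32, heads=2, residual=True)
out = conv((x1, x2), edge_index)   # shape [2, 64]
```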
6 changes: 4 additions & 2 deletions test/nn/conv/test_gatv2_conv.py
@@ -1,5 +1,6 @@
from typing import Tuple

import pytest
import torch
from torch import Tensor

@@ -10,13 +11,14 @@
from torch_geometric.utils import to_torch_csc_tensor


def test_gatv2_conv():
@pytest.mark.parametrize('residual', [False, True])
def test_gatv2_conv(residual):
x1 = torch.randn(4, 8)
x2 = torch.randn(2, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))

conv = GATv2Conv(8, 32, heads=2)
conv = GATv2Conv(8, 32, heads=2, residual=residual)
assert str(conv) == 'GATv2Conv(8, 32, heads=2)'
out = conv(x1, edge_index)
assert out.size() == (4, 64)
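The same option applies to `GATv2Conv`; a brief usage sketch following the
test above (not part of the diff):

```python
import torch
from torch_geometric.nn import GATv2Conv

x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

conv = GATv2Conv(8, 32, heads=2, residual=True)
out = conv(x, edge_index)  # shape [4, 64]; the residual branch maps 8 -> 64
```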
37 changes: 33 additions & 4 deletions torch_geometric/nn/conv/gat_conv.py
@@ -107,6 +107,8 @@ class GATConv(MessagePassing):
:obj:`"min"`, :obj:`"max"`, :obj:`"mul"`). (default: :obj:`"mean"`)
bias (bool, optional): If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
residual (bool, optional): If set to :obj:`True`, the layer will add
a learnable skip-connection. (default: :obj:`False`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.
@@ -137,6 +139,7 @@ def __init__(
edge_dim: Optional[int] = None,
fill_value: Union[float, Tensor, str] = 'mean',
bias: bool = True,
residual: bool = False,
**kwargs,
):
kwargs.setdefault('aggr', 'add')
@@ -151,6 +154,7 @@ def __init__(
self.add_self_loops = add_self_loops
self.edge_dim = edge_dim
self.fill_value = fill_value
self.residual = residual

# In case we are operating in bipartite graphs, we apply separate
# transformations 'lin_src' and 'lin_dst' to source and target nodes:
@@ -176,10 +180,22 @@ def __init__(
self.lin_edge = None
self.register_parameter('att_edge', None)

if bias and concat:
self.bias = Parameter(torch.empty(heads * out_channels))
elif bias and not concat:
self.bias = Parameter(torch.empty(out_channels))
# The number of output channels:
total_out_channels = out_channels * (heads if concat else 1)

if residual:
self.res = Linear(
in_channels
if isinstance(in_channels, int) else in_channels[1],
total_out_channels,
bias=False,
weight_initializer='glorot',
)
else:
self.register_parameter('res', None)

if bias:
self.bias = Parameter(torch.empty(total_out_channels))
else:
self.register_parameter('bias', None)

@@ -195,6 +211,8 @@ def reset_parameters(self):
self.lin_dst.reset_parameters()
if self.lin_edge is not None:
self.lin_edge.reset_parameters()
if self.res is not None:
self.res.reset_parameters()
glorot(self.att_src)
glorot(self.att_dst)
glorot(self.att_edge)
@@ -270,11 +288,16 @@ def forward( # noqa: F811

H, C = self.heads, self.out_channels

res: Optional[Tensor] = None

# We first transform the input node features. If a tuple is passed, we
# transform source and target node features via separate weights:
if isinstance(x, Tensor):
assert x.dim() == 2, "Static graphs not supported in 'GATConv'"

if self.res is not None:
res = self.res(x)

if self.lin is not None:
x_src = x_dst = self.lin(x).view(-1, H, C)
else:
@@ -288,6 +311,9 @@ def forward( # noqa: F811
x_src, x_dst = x
assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'"

if x_dst is not None and self.res is not None:
res = self.res(x_dst)

if self.lin is not None:
# If the module is initialized as non-bipartite, we expect that
# source and destination node features have the same shape and
@@ -344,6 +370,9 @@ def forward( # noqa: F811
else:
out = out.mean(dim=1)

if res is not None:
out = out + res

if self.bias is not None:
out = out + self.bias

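To confirm that the skip-connection is learnable, one can inspect the
registered parameters. A hedged sketch of what to expect (`res.weight`
follows from the diff above; the remaining names are assumptions about the
existing module):

```python
from torch_geometric.nn import GATConv

conv = GATConv(8, 32, heads=2, residual=True)
print(sorted(name for name, _ in conv.named_parameters()))
# Expected to include 'res.weight' for the residual projection, alongside
# the usual attention and transformation parameters (e.g. 'att_src',
# 'att_dst', 'bias', 'lin.weight').
```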
39 changes: 35 additions & 4 deletions torch_geometric/nn/conv/gatv2_conv.py
@@ -110,6 +110,8 @@ class GATv2Conv(MessagePassing):
will be applied to the source and the target node of every edge,
*i.e.* :math:`\mathbf{\Theta}_{s} = \mathbf{\Theta}_{t}`.
(default: :obj:`False`)
residual (bool, optional): If set to :obj:`True`, the layer will add
a learnable skip-connection. (default: :obj:`False`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.
@@ -141,6 +143,7 @@ def __init__(
fill_value: Union[float, Tensor, str] = 'mean',
bias: bool = True,
share_weights: bool = False,
residual: bool = False,
**kwargs,
):
super().__init__(node_dim=0, **kwargs)
@@ -154,6 +157,7 @@ def __init__(
self.add_self_loops = add_self_loops
self.edge_dim = edge_dim
self.fill_value = fill_value
self.residual = residual
self.share_weights = share_weights

if isinstance(in_channels, int):
@@ -181,10 +185,22 @@ def __init__(
else:
self.lin_edge = None

if bias and concat:
self.bias = Parameter(torch.empty(heads * out_channels))
elif bias and not concat:
self.bias = Parameter(torch.empty(out_channels))
# The number of output channels:
total_out_channels = out_channels * (heads if concat else 1)

if residual:
self.res = Linear(
in_channels
if isinstance(in_channels, int) else in_channels[1],
total_out_channels,
bias=False,
weight_initializer='glorot',
)
else:
self.register_parameter('res', None)

if bias:
self.bias = Parameter(torch.empty(total_out_channels))
else:
self.register_parameter('bias', None)

@@ -196,6 +212,8 @@ def reset_parameters(self):
self.lin_r.reset_parameters()
if self.lin_edge is not None:
self.lin_edge.reset_parameters()
if self.res is not None:
self.res.reset_parameters()
glorot(self.att)
zeros(self.bias)

@@ -255,10 +273,16 @@ def forward( # noqa: F811
"""
H, C = self.heads, self.out_channels

res: Optional[Tensor] = None

x_l: OptTensor = None
x_r: OptTensor = None
if isinstance(x, Tensor):
assert x.dim() == 2

if self.res is not None:
res = self.res(x)

x_l = self.lin_l(x).view(-1, H, C)
if self.share_weights:
x_r = x_l
@@ -267,6 +291,10 @@ def forward( # noqa: F811
else:
x_l, x_r = x[0], x[1]
assert x[0].dim() == 2

if x_r is not None and self.res is not None:
res = self.res(x_r)

x_l = self.lin_l(x_l).view(-1, H, C)
if x_r is not None:
x_r = self.lin_r(x_r).view(-1, H, C)
@@ -305,6 +333,9 @@ def forward( # noqa: F811
else:
out = out.mean(dim=1)

if res is not None:
out = out + res

if self.bias is not None:
out = out + self.bias

