Contextualized Messages Boost Graph Representations

Our paper Contextualized Messages Boost Graph Representations has recently been posted on arXiv (April 2025 update: published in TMLR). While the main theoretical results are presented in the paper, this post aims to provide an intuitive understanding of the motivation and findings. In summary, the paper theoretically justifies the need for anisotropic (i.e., a function of both the features of the center and neighboring nodes) and dynamic (i.e., a universal function approximator) message functions in graph neural networks (GNNs) and proposes the soft-isomorphic relational graph convolution network (SIR-GCN), a simple and computationally efficient model that satisfies this requirement and empirically outperforms comparable models.

  1. Graph Neural Networks
    1. Graph Convolution Network (GCN)
    2. Graph Sample and Aggregate (GraphSAGE)
    3. Graph Attention Network (GAT)
    4. Graph Isomorphism Network (GIN)
    5. Principal Neighborhood Aggregation (PNA)
  2. Soft-Isomorphic Relational Graph Convolution Network (SIR-GCN)
    1. Implementation
  3. Experimental Results
  4. Conclusion

Before delving into the paper and its main results, let us first review GNNs and how they operate.

Graph Neural Networks

Let $\mathcal{G} = \left(\mathcal{V}, \mathcal{E}\right)$ be a graph and $\mathcal{N}(u) \subseteq \mathcal{V}$ the set of nodes adjacent to node $u \in \mathcal{V}$. Suppose $\mathcal{H}$ is the space of node features and $\boldsymbol{h_u} \in \mathcal{H}$ is the feature of node $u$. A GNN following the message-passing scheme first aggregates the features of each node’s neighbors and then combines them to create an updated node feature. This can be expressed mathematically as

$$\begin{aligned} \boldsymbol{H_u} &= \left\{\!\!\left\{\boldsymbol{h_v} : v \in \mathcal{N}(u)\right\}\!\!\right\} \\ \boldsymbol{a_u} &= AGG\left(\boldsymbol{H_u}\right) \\ \boldsymbol{h^*_u} &= COMB\left(\boldsymbol{h_u}, \boldsymbol{a_u}\right) \end{aligned}$$

where AGG and COMB are some aggregation and combination strategies, respectively, $\boldsymbol{H_u}$ is the multiset of neighborhood features for node $u$, $\boldsymbol{a_u}$ is the aggregated neighborhood feature for node $u$, and $\boldsymbol{h^*_u}$ is the updated feature for node $u$. An illustration is provided below.

Figure: Message-passing scheme for GNNs.

Since AGG takes arbitrary-sized multisets of neighborhood features as input and transforms them into a single feature, it may be viewed as a hash function. Hence, the terms aggregation function and hash function are used interchangeably.
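As a concrete toy illustration of the scheme above, the following minimal PyTorch sketch performs one round of message passing with sum aggregation and concatenation as COMB; the adjacency list and feature values are hypothetical.

import torch

# Hypothetical toy graph: each node's neighbors listed by index
neighbors = {0: [1, 2], 1: [0], 2: [0, 1]}
h = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # node features

def aggregate(H_u):
    # AGG: sum over the multiset of neighborhood features
    return torch.stack(H_u).sum(dim=0)

def combine(h_u, a_u):
    # COMB: here, simply concatenate the center and aggregated features
    return torch.cat([h_u, a_u])

h_new = torch.stack([
    combine(h[u], aggregate([h[v] for v in neighbors[u]]))
    for u in neighbors
])
print(h_new.shape)  # torch.Size([3, 4])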

To provide additional context, five key GNNs from the literature are introduced below.

Graph Convolution Network (GCN)

The graph convolution network (GCN) is one of the classical GNNs in the literature. It can be expressed mathematically as

$$\boldsymbol{h^*_u} = \sum_{v \in \mathcal{N}(u)} \dfrac{1}{\sqrt{\left|\mathcal{N}(u)\right|}\sqrt{\left|\mathcal{N}(v)\right|}} \boldsymbol{W} \boldsymbol{h_v}$$

where $\boldsymbol{W}$ represents a linear transformation of the neighborhood features $\boldsymbol{h_v}$.
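For intuition, a minimal dense-matrix sketch of this update in plain PyTorch is given below; it omits the self-loops and other details of the original GCN, and the toy adjacency matrix and dimensions are hypothetical.

import torch

A = torch.tensor([[0., 1., 1.],
                  [1., 0., 0.],
                  [1., 0., 0.]])           # adjacency matrix of a toy graph
deg = A.sum(dim=1)                         # node degrees |N(u)|
D_inv_sqrt = torch.diag(deg.pow(-0.5))     # D^{-1/2}
A_norm = D_inv_sqrt @ A @ D_inv_sqrt       # symmetric normalization

H = torch.randn(3, 4)                      # node features
W = torch.randn(4, 8)                      # linear transformation W
H_new = A_norm @ H @ W                     # GCN update (without self-loops)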

Graph Sample and Aggregate (GraphSAGE)

The graph sample and aggregate (GraphSAGE) is another classic GNN designed for inductive representation learning. GraphSAGE with max pooling can be expressed as

$$\boldsymbol{h^*_u} = \max_{v \in \mathcal{N}(u)} \left\{\sigma\left(\boldsymbol{W}\boldsymbol{h_v} + \boldsymbol{b}\right)\right\}$$

where $\boldsymbol{W}$ and $\boldsymbol{b}$ represent a linear transformation of the neighborhood features $\boldsymbol{h_v}$ and $\sigma$ is a non-linear activation function.
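A minimal sketch of the max-pooling update for a single node u, assuming ReLU as the activation and hypothetical dimensions, might look like:

import torch
from torch import nn

h_nbrs = torch.randn(3, 4)                # features of the 3 neighbors of node u
pool = nn.Linear(4, 8)                    # W and b
h_u_new = torch.relu(pool(h_nbrs)).max(dim=0).values  # element-wise max over neighbors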

Graph Attention Network (GAT)

The graph attention network (GAT) leverages the popular attention mechanism to dynamically assign weights to each neighborhood feature. GATv2 can be expressed as

$$\begin{aligned} \boldsymbol{h^*_u} &= \sum_{v \in \mathcal{N}(u)} \alpha_{u,v} \cdot \boldsymbol{W} \boldsymbol{h_v} \\ \alpha_{u,v} &= \dfrac{\exp\left(e_{u,v}\right)}{\sum_{w \in \mathcal{N}(u)} \exp\left(e_{u,w}\right)} \\ e_{u,v} &= \boldsymbol{a}^\top \text{LeakyReLU}\left(\boldsymbol{W_Q}\boldsymbol{h_u} + \boldsymbol{W_K}\boldsymbol{h_v}\right) \end{aligned}$$

where $\boldsymbol{W}$ represents a linear transformation of the neighborhood features $\boldsymbol{h_v}$ and $\boldsymbol{a}, \boldsymbol{W_Q}, \boldsymbol{W_K}$ jointly represent the attention mechanism.
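A minimal sketch of the GATv2-style attention for a single node u (single attention head, hypothetical dimensions) is shown below:

import torch
from torch import nn
import torch.nn.functional as F

d_in, d_hidden = 4, 8
W_Q = nn.Linear(d_in, d_hidden, bias=False)
W_K = nn.Linear(d_in, d_hidden, bias=False)
W = nn.Linear(d_in, d_hidden, bias=False)
a = nn.Parameter(torch.randn(d_hidden))

h_u = torch.randn(d_in)                       # center node feature
h_nbrs = torch.randn(3, d_in)                 # features of the 3 neighbors of u

e = F.leaky_relu(W_Q(h_u) + W_K(h_nbrs)) @ a  # unnormalized scores e_{u,v}
alpha = torch.softmax(e, dim=0)               # attention weights alpha_{u,v}
h_u_new = (alpha.unsqueeze(-1) * W(h_nbrs)).sum(dim=0)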

Graph Isomorphism Network (GIN)

The graph isomorphism network (GIN) is designed to be a maximally powerful GNN when the space of node features is countable. It is based on the Weisfeiler-Lehman (WL) graph isomorphism test. GIN can be expressed as

$$\boldsymbol{h^*_u} = MLP\left((1 + \epsilon) \cdot \boldsymbol{h_u} + \sum_{v \in \mathcal{N}(u)} \boldsymbol{h_v}\right)$$

where $\epsilon$ is a scalar that may be learnable and MLP is a multi-layer perceptron.
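A minimal sketch of the GIN update for a single node u, with a hypothetical two-layer MLP and dimensions:

import torch
from torch import nn

mlp = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8))
eps = nn.Parameter(torch.zeros(1))       # (optionally learnable) epsilon

h_u = torch.randn(4)                     # center node feature
h_nbrs = torch.randn(3, 4)               # neighbor features

h_u_new = mlp((1 + eps) * h_u + h_nbrs.sum(dim=0))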

For more details about countable and uncountable sets, see Wikipedia.

Principal Neighborhood Aggregation (PNA)

The principal neighborhood aggregation (PNA) is a recently proposed GNN designed to be maximally expressive when the space of node features is uncountable. It can be expressed as

$$\boldsymbol{h^*_u} = \bigoplus_{v \in \mathcal{N}(u)} \left[\boldsymbol{W_Q} \boldsymbol{h_u} + \boldsymbol{W_K} \boldsymbol{h_v}\right]$$

where $\bigoplus$ comprises multiple aggregators (e.g., mean, standard deviation, max, and min) and $\boldsymbol{W_Q}, \boldsymbol{W_K}$ are linear transformations of the center node features $\boldsymbol{h_u}$ and neighborhood features $\boldsymbol{h_v}$, respectively.

Technically speaking, PNA also introduces degree scalers as part of the aggregators $\bigoplus$. We ignore these for ease of discussion.
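Ignoring the degree scalers as noted above, a minimal sketch of the multi-aggregator update for a single node u (hypothetical dimensions) is:

import torch
from torch import nn

d_in, d_hidden = 4, 8
W_Q = nn.Linear(d_in, d_hidden, bias=False)
W_K = nn.Linear(d_in, d_hidden, bias=False)

h_u = torch.randn(d_in)                  # center node feature
h_nbrs = torch.randn(3, d_in)            # neighbor features

msgs = W_Q(h_u) + W_K(h_nbrs)            # one message per edge into u
a_u = torch.cat([                        # concatenation of multiple aggregators
    msgs.mean(dim=0),
    msgs.std(dim=0),
    msgs.max(dim=0).values,
    msgs.min(dim=0).values,
])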

Soft-Isomorphic Relational Graph Convolution Network (SIR-GCN)

Based on the models introduced above, we can formulate the hash function in the message-passing framework as

$$G\left(\boldsymbol{H}\right) = \sum_{\boldsymbol{h} \in \boldsymbol{H}} g\left(\boldsymbol{h}\right)$$

where $g$ is the feature map or message function transforming the features of neighboring nodes $\boldsymbol{h} \in \boldsymbol{H}$. Notably, the majority of existing GNNs simply use a linear message function $g$, resulting in a hash function that is linear with respect to the neighborhood features.

Intuitively, the sum aggregation may also be replaced with the mean, symmetric mean, or max aggregation.

To illustrate the limitations of existing models, consider node $u$ with two neighbors $v_1$ and $v_2$ and the task of anomaly detection on the scalar node features $\boldsymbol{h_{v_1}}$ and $\boldsymbol{h_{v_2}}$ representing zero-mean scores. If $g$ is linear, the corresponding contour plot of the hash function $G$ in Figure (a) below highlights collisions (instances where distinct inputs produce identical outputs) between dissimilar multisets of neighborhood features, resulting in aggregated neighborhood features that are less useful for the task. For instance, both $\boldsymbol{H_{u}^{(1)}} = \left\{0, 0\right\}$ and $\boldsymbol{H_{u}^{(2)}} = \left\{-10, 10\right\}$ will produce an aggregated neighborhood feature of $\boldsymbol{a_{u}} = 0 + 0 = -10 + 10 = 0$ even if they are intuitively dissimilar.

Figure: Hash collisions in GNNs.

In contrast, for more complex message functions $g$, the corresponding hash function $G$ becomes non-trivial and possibly more meaningful. For instance, if $g\left(\boldsymbol{h}\right) = -\boldsymbol{h}^2$, then the corresponding hash function $G$ in Figure (b) produces hash collisions that are more useful and intuitive for detecting anomalous scores. In practice, this may be achieved by applying activation functions or MLPs before or after each GNN layer to induce non-linearity in $g$.
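This toy example can be checked numerically. The sketch below contrasts a linear message function with $g(h) = -h^2$ on the two multisets above: the linear hash collides on the intuitively dissimilar multisets, while the quadratic one separates them.

def hash_fn(multiset, g):
    return sum(g(h) for h in multiset)

H1 = [0.0, 0.0]      # no anomalous scores
H2 = [-10.0, 10.0]   # two anomalous scores

linear = lambda h: h
quadratic = lambda h: -h ** 2

print(hash_fn(H1, linear), hash_fn(H2, linear))        # 0.0 0.0    -> collision
print(hash_fn(H1, quadratic), hash_fn(H2, quadratic))  # 0.0 -200.0 -> distinguished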

Nevertheless, this approach still has its limitations. Specifically, since the message function $g$ is solely a function of the neighborhood features $\boldsymbol{h_v} \in \boldsymbol{H_u}$, the complex relationship between neighboring nodes cannot be learned. Figures (c) and (d) provide intuition for the necessity of $g$ to consider both the features of the center and neighboring nodes: for nodes $u \neq u' \in \mathcal{V}$ having the same two neighbors $v_1$ and $v_2$ but $\boldsymbol{h_u} \neq \boldsymbol{h_{u'}}$, the introduction of the bias terms $\boldsymbol{0}$ and $\boldsymbol{1}$ (assumed to be a function of the center node features $\boldsymbol{h_u}$ and $\boldsymbol{h_{u'}}$, respectively) results in distinct aggregated neighborhood features $\boldsymbol{a_u} \neq \boldsymbol{a_{u'}}$ even if the multisets of neighborhood features are identical, $\boldsymbol{H_u} = \boldsymbol{H_{u'}}$. This makes the message function $g$ and hash function $G$ adaptive with respect to the features of the center node $u$, allowing the model to learn the relationship between nodes connected by an edge and produce contextualized aggregated neighborhood features.

In the event that $\boldsymbol{h_u} = \boldsymbol{h_{u'}}$ for nodes $u \neq u'$, one may inject stochasticity into the node features to distinguish, with high probability, between nodes with identical features and identical neighborhood features (Sato et al.).

In line with this observation, we emphasize the need for a dynamic (i.e., a universal function approximator) message function that takes both the features of the center and neighboring nodes as input (i.e., anisotropic). By the universal approximation theorem, this message function $g$ may be modeled as a two-layer MLP, resulting in the soft-isomorphic relational graph convolution network (SIR-GCN)

$$\boldsymbol{h_u^*} = \sum_{v \in \mathcal{N}(u)} \boldsymbol{W_R} \, \sigma\left(\boldsymbol{W_Q} \boldsymbol{h_u} + \boldsymbol{W_K} \boldsymbol{h_v}\right)$$

where $\sigma$ is a non-linear activation function, $\boldsymbol{W_Q}, \boldsymbol{W_K} \in \mathbb{R}^{d_{\text{hidden}} \times d_{\text{in}}}$, and $\boldsymbol{W_R} \in \mathbb{R}^{d_{\text{out}} \times d_{\text{hidden}}}$. These jointly represent a contextualized (i.e., anisotropic and dynamic) message function. Leveraging linearity, the proposed model has a computational complexity of

$$\mathcal{O}\left(\left|\mathcal{V}\right| \times d_{\text{hidden}} \times d_{\text{in}} + \left|\mathcal{E}\right| \times d_{\text{hidden}} + \left|\mathcal{V}\right| \times d_{\text{out}} \times d_{\text{hidden}}\right)$$

with computational efficiency achieved by applying only the activation function along the edges.
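This linearity can be verified directly: since $\boldsymbol{W_R}$ is linear and the aggregation is a sum, $\boldsymbol{W_R}$ may be applied once per node after aggregation rather than once per edge. A minimal sketch for a single node u (hypothetical dimensions, ReLU as $\sigma$) is given below.

import torch
from torch import nn

d_in, d_hidden, d_out = 4, 8, 6
W_Q = nn.Linear(d_in, d_hidden, bias=False)
W_K = nn.Linear(d_in, d_hidden, bias=False)
W_R = nn.Linear(d_hidden, d_out, bias=False)

h_u = torch.randn(d_in)                   # center node feature
h_nbrs = torch.randn(5, d_in)             # features of the 5 neighbors of u

msgs = torch.relu(W_Q(h_u) + W_K(h_nbrs)) # contextualized messages, one per edge

per_edge = W_R(msgs).sum(dim=0)           # apply W_R along every edge, then sum
per_node = W_R(msgs.sum(dim=0))           # sum first, then apply W_R once per node
assert torch.allclose(per_edge, per_node, atol=1e-5)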

Our proposed model is analogous to the recently proposed Kolmogorov-Arnold Networks (KANs). Similar to KANs, instead of applying activation functions on nodes, SIR-GCN applies non-linearities along edges. This allows the model to “reason along an edge.”

Our paper also presents a mathematical discussion on the relationship between SIR-GCN and the key GNNs introduced above, establishing the former as a generalization of classical GNN methodologies. This highlights the novelty of SIR-GCN as the first GNN model to incorporate both anisotropic and dynamic messages. The discussion further demonstrates that SIR-GCN is comparable to a modified WL test in terms of graph isomorphism representational capability, thereby inheriting the theoretical properties of the WL test.

Implementation

The code for SIR-GCN using the Deep Graph Library (DGL) framework is provided below.

import torch
from torch import nn
from dgl import function as fn
from dgl.utils import expand_as_pair

class SIRConv(nn.Module):
    r"""Soft-Isomorphic Relational Graph Convolution Network (SIR-GCN)
    
    .. math::
        h_u^* = \sum_{v \in \mathcal{N}(u)} W_R \sigma(W_Q h_u + W_K h_v)

    Parameters
    ----------
    input_dim : int
        Input feature dimension
    hidden_dim : int
        Hidden feature dimension
    output_dim : int
        Output feature dimension
    activation : a callable layer
        Activation function, the :math:`\sigma` in the formula
    dropout : float, optional
        Dropout rate for inner linear transformations, defaults to 0
    inner_bias : bool, optional
        Whether to learn an additive bias for inner linear transformations, defaults to True
    outer_bias : bool, optional
        Whether to learn an additive bias for outer linear transformations, defaults to False
    agg_type : str, optional
        Aggregator type to use (``sum``, ``max``, ``mean``, or ``sym``), defaults to ``sum``
    """
    def __init__(self, input_dim, hidden_dim, output_dim, activation, dropout=0, inner_bias=True, outer_bias=False, agg_type='sum'):
        super(SIRConv, self).__init__()
        self.activation = activation
        self.dropout = nn.Dropout(dropout)
        self.linear_query = nn.Linear(input_dim, hidden_dim, bias=inner_bias)
        self.linear_key = nn.Linear(input_dim, hidden_dim, bias=False)
        self.linear_relation = nn.Linear(hidden_dim, output_dim, bias=outer_bias)

        self._agg_type = agg_type
        self._agg_func = fn.sum if agg_type == 'sym' else getattr(fn, agg_type)
    
    def message_func(self, edges):
        if self._agg_type in ['sum', 'mean', 'sym']:
            return {'m': edges.src['norm'] * edges.dst['norm'] * self.activation(edges.dst['eq'] + edges.src['ek'])}
        else: # Cannot leverage linearity for max aggregation
            return {'m': self.linear_relation(self.activation(edges.dst['eq'] + edges.src['ek']))}
        
    def forward(self, graph, feat):
        with graph.local_scope():
            # Symmetric normalization terms (only used when agg_type == 'sym')
            degs = graph.in_degrees().float().clamp(min=1).to(graph.device)
            norm = torch.pow(degs, -0.5) if self._agg_type == 'sym' else torch.ones(graph.num_nodes(), device=graph.device)
            norm = norm.reshape((graph.num_nodes(),) + (1,) * (feat.dim() - 1))
            graph.ndata['norm'] = norm

            # Pre-compute W_Q h_u (query/center) and W_K h_v (key/neighbor) on nodes
            feat_key, feat_query = expand_as_pair(feat, graph)
            graph.ndata['ek'] = self.dropout(self.linear_key(feat_key))
            graph.ndata['eq'] = self.dropout(self.linear_query(feat_query))

            # Compute messages along edges and aggregate them on the destination nodes
            graph.update_all(self.message_func, self._agg_func('m', 'ft'))
            rst = graph.ndata.pop('ft')
            # Leverage linearity: apply W_R once per node after sum/mean/sym aggregation
            rst = self.linear_relation(rst) if self._agg_type in ['sum', 'mean', 'sym'] else rst

            return rst

SIR-GCN implementation in DGL.
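A minimal usage sketch, assuming the SIRConv class above is in scope and using a hypothetical toy graph, might look like:

import dgl
import torch
from torch import nn

# Hypothetical toy graph with 4 nodes and 5 directed edges
graph = dgl.graph(([0, 1, 1, 2, 3], [1, 0, 2, 1, 1]))
feat = torch.randn(4, 16)

conv = SIRConv(input_dim=16, hidden_dim=32, output_dim=8, activation=nn.ReLU())
out = conv(graph, feat)
print(out.shape)  # torch.Size([4, 8])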

Experimental Results

To highlight the contribution, we evaluated SIR-GCN on a subset of the Benchmarking GNNs, a collection of benchmark datasets consisting of diverse mathematical and real-world graphs across various GNN tasks. In particular, the WikiCS, PATTERN, and CLUSTER datasets fall under node property prediction tasks while the MNIST, CIFAR10, and ZINC datasets fall under graph property prediction tasks. Furthermore, the WikiCS, MNIST, and CIFAR10 datasets have uncountable node features while the remaining datasets have countable node features. The performance metric of ZINC is the mean absolute error (MAE) while the performance metric of the remaining datasets is accuracy. See Dwivedi et al. (2023) for more information regarding the individual datasets.

| Model | WikiCS (↑) | PATTERN (↑) | CLUSTER (↑) | MNIST (↑) | CIFAR10 (↑) | ZINC (↓) |
|---|---|---|---|---|---|---|
| MLP | 59.45 ± 2.33 | 50.52 ± 0.00 | 20.97 ± 0.00 | 95.34 ± 0.14 | 56.34 ± 0.18 | 0.706 ± 0.006 |
| GCN | 77.47 ± 0.85 | 85.50 ± 0.05 | 47.83 ± 1.51 | 90.12 ± 0.15 | 54.14 ± 0.39 | 0.416 ± 0.006 |
| GraphSAGE | 74.77 ± 0.95 | 50.52 ± 0.00 | 50.45 ± 0.15 | 97.31 ± 0.10 | 65.77 ± 0.31 | 0.468 ± 0.003 |
| GAT | 76.91 ± 0.82 | 75.82 ± 1.82 | 57.73 ± 0.32 | 95.54 ± 0.21 | 64.22 ± 0.46 | 0.475 ± 0.007 |
| GIN | 75.86 ± 0.58 | 85.59 ± 0.01 | 58.38 ± 0.24 | 96.49 ± 0.25 | 55.26 ± 1.53 | 0.387 ± 0.015 |
| PNA | - | - | - | 97.19 ± 0.08 | 70.21 ± 0.15 | 0.320 ± 0.032 |
| SIR-GCN | 78.06 ± 0.66 | 85.75 ± 0.03 | 63.35 ± 0.19 | 97.90 ± 0.08 | 71.98 ± 0.40 | 0.278 ± 0.024 |

Test performance on Benchmarking GNNs.

The table above presents the mean and standard deviation of the test performance for SIR-GCN and comparable GNNs across the six benchmarks, where the experimental setup follows that of Dwivedi et al. to ensure a fair evaluation. The results show that SIR-GCN consistently outperforms key GNNs in the literature. Notably, SIR-GCN also outperforms PNA, which uses multiple aggregators. This supports the intuition that anisotropic and dynamic messages enhance the expressivity of GNNs.

Conclusion

Overall, we proposed a novel perspective for creating a powerful GNN when the space of node features is uncountable. Our theoretical results suggest the need for contextualized message functions, which allow the model to learn the relationship between the features of the center and neighboring nodes.

Let me know your thoughts!