As a note beforehand, this guide is not meant as a comprehensive review or in-depth tutorial on GNNs; rather, it is meant to build intuition for what is happening under the hood of simple GNNs. The goal is that, by the end, you will be able to point at any operation inside a GNN and explain what it is doing, along with the shapes and meanings of all the tensors and neural network weights involved. A follow-up blog post will relate the pseudocode shown at the end to real Python and Pytorch Geometric code.

Data in the world often comes associated with some sort of underlying structure. For example, images come with a 2D grid structure, which allows us to group and analyze pixels within local regions together. We can make assumptions about the data and build these into our neural network architectures in the form of **inductive biases**, which helps the model learn and generalize on the data. Weight sharing and spatial locality in Convolutional Neural Networks (CNNs) are great examples of this.

Oftentimes, however, the data is structured in a more varied way, with entities connected to one another by relationships in real life. For example, humans are connected to one another in social networks through friendship connections and online interactions, which might be represented as a graph by defining each user as a **node** connected to other users by **edges** which represent online connections or interactions through posts. Molecular graphs connect atoms to other atoms through different types of chemical bonds, and a road network might define different cities as nodes which are connected to one another by a web of roads. Any one **“node”** entity in these graphs may be connected to any number of other entities through **edge** connections, which means that any neural network we design to learn on this graph-structured data will need to have a very generalized aggregation scheme to effectively integrate information from nodes and their surrounding neighbors. What is the benefit of representing this real-world data as a graph, rather than some other conventional data format? It allows us to flexibly model relationships between any number of entities connected by any number of edges, without having to simplify or project our data into a simpler format.

Furthermore, the entities and relationships we define in our graphs can capture more complexities which we find in real-world data. In social media platforms, for example, think about individual users and companies both being considered users on the platform, who can write posts and be a part of subcommunities. All of these can be defined as separate types of nodes and edges, with different associated feature attributes. We can even have multi-hop relationships (e.g. a friend of a friend), which can make for some fascinating modeling challenges! We’ll leave that for another post, and stick to basic **homogeneous graphs** for now, where we deal with only one type of entity.

We’ve seen examples of graph-structured data, however we need a principled way of representing the feature attributes and connectivity information of a graph in matrices, so that we can do operations on them and learn from data using neural networks. Let’s define a few matrices which will tell us how we hold the graph data, namely the **node feature matrix** and the **adjacency matrix**:

For now, let’s keep looking at a small molecular graph from earlier, made up of six blue and green atoms numbered 1 to 6. We have two matrices which hold all of the information we need to describe a simple graph, so let’s take a closer look and understand what is in each matrix.

- The node feature matrix is a matrix which contains all of the features for all nodes in our graph. The shape of this matrix will be [number_of_nodes x number_of_features], which is [6 x 4] in our small example above, and is usually denoted as \(X\). With \(N=6\) nodes and \(F=4\) features, we have \(X \in R^{N \times F}\). You can imagine that the four features might be attributes of each atom, such as its atomic number, atomic mass, charge, and other relevant attributes.

- The adjacency matrix is a (usually) binary matrix which contains information about what nodes are connected to what other nodes in the graph. This helps us keep track of connections, which we will need once we define a neural network architecture to aggregate information from the surrounding neighbors of each node. Information aggregation in graphs is useful because learning on graphs involves both understanding nodes as well as how they interact with and are similar to their neighboring nodes.
- The shape of the adjacency matrix will be [number_of_nodes x number_of_nodes], which will be [6 x 6] in our small example and is usually denoted as \(A \in R^{N \times N}\). Edges usually have some directionality (a “source” node and “destination” node), so by convention we say that source nodes are the rows and destination nodes are the columns of the matrix, with a 1 indicating an edge between source node \(u\) and destination node \(v\).
- You’ll notice that the entries along the diagonal of the adjacency matrix are all 1s, and are highlighted in green. We have a choice in modeling our graph of whether we want to consider a node as connected to itself or not (it may or may not make a difference depending on our data and GNN architecture). For cases where a node’s features or state should affect its own state in the future (i.e. an atom’s embedding should reflect the atom’s identity along with the other atoms it is connected to), it is generally good to include self-loops. For this simple example, we will include self-connections to connect atoms to themselves.
- You will also notice that the adjacency matrix is symmetric around its diagonal; this means we are working with an undirected graph (atom 1 being connected to atom 2 means 2 is connected to 1 as well). This is not always the case: think about a citation network, for example, where paper A citing paper B does not mean the reverse is true. The full adjacency matrix for our example is written out just after this list.
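
Written out explicitly (this is the same six-atom adjacency matrix used in the code example of the follow-up post), you can see the 1s on the diagonal for the self-loops and the symmetric pattern off the diagonal:

\[A = \begin{bmatrix} 1 & 1 & 0 & 0 & 0 & 0 \\ 1 & 1 & 1 & 0 & 1 & 0 \\ 0 & 1 & 1 & 1 & 0 & 0 \\ 0 & 0 & 1 & 1 & 1 & 1 \\ 0 & 1 & 0 & 1 & 1 & 0 \\ 0 & 0 & 0 & 1 & 0 & 1 \end{bmatrix}\]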

With these two matrices, we have everything we need to numerically describe our graph-structured data. The node feature matrix \(X\) can be seen as initial/input node features, and our goal for learning on graphs will be to learn node embeddings \(H \in R^{N \times D}\), where \(D\) is some hidden dimension which we choose, which meaningfully represent each node for downstream tasks based on both the node’s input features and the neighboring nodes it was connected to. Downstream tasks may include **node-level** tasks such as classifying what type of atom each node is, **edge-level** tasks such as classifying what bond type two atoms should have between one another, and **graph-level** tasks such as predicting whether the molecule as a whole is toxic or not. You can imagine how, depending on the task, it is important for each atom to integrate information from neighboring atoms and have an overall picture of where it is in relation to the whole molecule.
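
To make these task levels concrete, here is a minimal sketch (my own illustration rather than anything from a specific library) of how learned node embeddings \(H\) might feed each kind of head; the pooling and scoring choices here are just one common option among many:

```
import torch
import torch.nn as nn

N, D, num_classes = 6, 16, 2
H = torch.randn(N, D)                 # learned node embeddings [num_nodes, hidden_dim]

# Node-level task: classify each node directly from its embedding
node_head = nn.Linear(D, num_classes)
node_logits = node_head(H)            # [num_nodes, num_classes]

# Edge-level task: score a candidate edge (u, v) from the pair of embeddings
u, v = 0, 1
edge_score = (H[u] * H[v]).sum()      # simple dot-product edge score

# Graph-level task: pool all node embeddings into one graph embedding, then classify
graph_embedding = H.mean(dim=0)       # [hidden_dim]
graph_head = nn.Linear(D, 1)
graph_logit = graph_head(graph_embedding)   # e.g. toxic / not toxic
```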

Now that we’ve seen our data and represented it using node feature and adjacency matrices, let’s get into actually learning on graph-structured data. Because graph data varies in both number of nodes and edge connections between nodes, we need a neural network architecture which can operate on arbitrary node entities with variable number of neighbors while producing meaningful node embeddings for our task. On images, we usually perform information aggregation by taking advantage of spatial locality in images, convolving over groups of pixels to form higher-level abstract features. On graphs, however, we are going to define a **graph convolution**, which aggregates information from a node and all of its neighbors, and updates that node’s learned embedding in a message-passing step.

Many GNN architectures have been proposed with varying forms of graph convolutions, and several of the simple, classic GNNs are still used (Graph Convolutional Networks (GCNs) [1], GraphSAGE [2], and Graph Attention Networks [3], to name a few). When learning about GNNs, however, it can be helpful to start by thinking about **message-passing neural networks (MPNNs)**, an abstraction of GNN frameworks for learning on graphs proposed in [4]. MPNNs are a general framework where nodes pass messages to one another along edges in the graph in three defined steps:

- **Message:** every node decides how to send information to the neighboring nodes it is connected to by edges
- **Aggregate:** each node receives messages from all of its neighbors (who also passed messages) and decides how to combine the information from all of its neighbors
- **Update:** each node decides how to combine the aggregated neighborhood information with its own information, and updates its embedding for the next timestep

If we can define these three operations, then we can have all nodes pass each other information in what is considered one message passing step, which disseminates information around the graph a bit. This can be repeated for \(K\) iterations, and the more times we pass information around (larger \(K\)), the more we diffuse information around the graph, which affects the embeddings we get at the end. One way I like to think about this is a group of people spaced 1 step apart from each other, iteratively telling those next to them their name + any other names they have heard from their neighbors. After K rounds of name-telling (information-passing), any one person will have heard the name of all people within K steps of them at least once.
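
As a toy sketch of this procedure (my own illustration with made-up scalar features and a deliberately simple update rule, not the molecular example from above):

```
# 3 nodes in a line graph (0 - 1 - 2), scalar "embeddings", sum aggregation,
# and a trivial update: self + aggregated neighborhood.
neighbors = {0: [1], 1: [0, 2], 2: [1]}
h = {0: 1.0, 1: 2.0, 2: 3.0}   # initial node features
K = 2                          # number of message-passing rounds

for k in range(K):
    new_h = {}
    for v in neighbors:
        messages = [h[u] for u in neighbors[v]]   # Message: neighbors send their current embeddings
        h_neigh = sum(messages)                   # Aggregate: permutation-invariant sum
        new_h[v] = h[v] + h_neigh                 # Update: combine self and neighborhood
    h = new_h                                     # embeddings for timestep k+1

print(h)  # after 2 rounds, node 0 has been influenced by node 2, which is two hops away
```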

Finally, if we incorporate some learned weights from a neural network into our message-passing operations and define a loss function on the resulting embeddings for some downstream task (e.g. node classification), then we have all of the ingredients we need for learning on graph-structured data.

Let’s zoom in a bit on each step for one destination node \(v\), define some notation, and visualize how the node feature matrix and adjacency matrix are going into each operation:

**Message:** source node \(u\) will pass message \(m_{uv}\) to destination node \(v\), which is node 2 in our small example.

- What exactly is the message? It depends on the GNN architecture! For simplicity, we will go with the easiest message node \(u\) can give to node \(v\): simply passing its node feature vector \(h_u\) to \(v\). More complex GNNs might do some learned operations to come up with a better message.

**Aggregate:** we can choose some aggregation function to combine information from neighboring nodes, such as SUM or MEAN, which works across any number of neighboring nodes. This gives us a combined neighborhood node embedding denoted as \(h_{N(v)}\), where \(N(v)\) denotes the neighborhood of destination node \(v\), meaning all nodes connected to node \(v\).

- \[h_{N(v)}^{k+1} = \text{AGGREGATE}(\{h_u^k, \forall u \in N(v)\})\]
- Note: we usually need to choose a permutation-invariant function to aggregate neighboring node messages. This is because neighboring nodes don’t have an ordering with respect to the destination node, so our aggregate function needs to give the same output no matter the ordering of the inputs (a small check of this is shown below).
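
Here is a quick check of that property (a tiny example of my own with three arbitrary message vectors): sum and mean give the same result regardless of ordering, while a concatenation of the messages does not.

```
import torch

messages = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])   # three incoming messages
shuffled = messages[[2, 0, 1]]                             # same messages, different order

print(torch.equal(messages.sum(dim=0), shuffled.sum(dim=0)))    # True: sum is permutation-invariant
print(torch.equal(messages.mean(dim=0), shuffled.mean(dim=0)))  # True: mean is permutation-invariant
print(torch.equal(messages.reshape(-1), shuffled.reshape(-1)))  # False: concatenation depends on order
```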

**Update:** we can concatenate the neighborhood embedding \(h_{N(v)}^{k+1}\) with the embedding of the node itself, \(h_v^k\), and parameterize it with some learned weights \(W\) and a nonlinearity \(\sigma\) to form our final update step:

- \[h_v^{k+1} = \sigma(W \cdot \text{CONCAT}(h_v^k, h_{N(v)}^{k+1}))\]

And now we’ve done it! We’ve made it through one message passing step, and if we repeat this for all destination nodes v, then we have our updated node embeddings for the next timestep \(k+1\).

The GraphSAGE paper [2] introduces a pseudocode algorithm for message passing which I quite like, and will put below for those thinking about the overall algorithm. This is actually the first algorithm I dissected as an undergraduate student to understand each operation and relate it to code implementations (which I will do in another blog post!).

It is quite a powerful algorithm when you think about it: in one code block, containing 10 lines, we can define a sequence of operations that encompasses how all MPNNs operate on arbitrary graph-structured data, and can become arbitrarily complex depending on how you define each of the three core operations: **message**, **aggregate**, and **update**.

The nice thing about thinking through the message-passing framework is that we can recover many classical GNN architectures depending on the choice of message, aggregate, and update operations. Here are a few examples I like to think of (simplifying a bit for the sake of explanation):

- If we choose our permutation-invariant aggregator to be a simple averaging, and include self-connections in our adjacency matrix, we can recover the original GCN architecture [1]. The GCN formulation defines this as a matrix multiplication, \(\tilde{A}XW\), which does the aggregation through matrix multiplication with a normalized adjacency matrix \(\tilde{A}\) (a small code sketch of this follows the list).
- In the message step, what if we consider how much the source node is important to the destination node, and assign a score for that edge? We could weigh the edges with these scores if we normalize them correctly, for example by making all incoming edge scores sum up to 1. Then, our aggregation is a weighted aggregation, which is more informative than assuming all neighboring nodes have the same importance. This is the main idea behind GATs [3].
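
To illustrate the first bullet, here is a small sketch of aggregation as a matrix multiplication with a normalized adjacency matrix. For simplicity I use row normalization (a plain neighborhood mean); the original GCN paper instead normalizes symmetrically as \(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}\), shown in the last lines. The feature and weight values here are random placeholders.

```
import torch

A = torch.tensor([            # adjacency matrix with self-loops (same 6-node example as the code post)
    [1, 1, 0, 0, 0, 0],
    [1, 1, 1, 0, 1, 0],
    [0, 1, 1, 1, 0, 0],
    [0, 0, 1, 1, 1, 1],
    [0, 1, 0, 1, 1, 0],
    [0, 0, 0, 1, 0, 1],
], dtype=torch.float32)
X = torch.randn(6, 4)         # node features [num_nodes, num_features]
W = torch.randn(4, 8)         # learned weight matrix [num_features, hidden_dim]

deg = A.sum(dim=1, keepdim=True)   # node degrees (including self-loops)
A_mean = A / deg                   # row-normalized adjacency: each row averages over its neighbors
H = A_mean @ X @ W                 # one round of "average the neighborhood, then transform"

# Symmetric normalization, as used in the original GCN propagation rule:
D_inv_sqrt = torch.diag(deg.squeeze() ** -0.5)
H_gcn = D_inv_sqrt @ A @ D_inv_sqrt @ X @ W
```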

**Final note:** thank you for reading through to the end of this blog post! I appreciate your attention, and hope these ideas are as useful to you in your work or studies as they were for me when I began studying GNNs. As this is my first blog post, I’d greatly appreciate any comments/tips/suggestions! The best place to reach me is at my email: syed [dot] rizvi [at] yale [dot] edu.

- Kipf, Thomas N., and Max Welling. “Semi-supervised classification with graph convolutional networks.” arXiv preprint arXiv:1609.02907 (2016).
- Hamilton, Will, Zhitao Ying, and Jure Leskovec. “Inductive representation learning on large graphs.” Advances in neural information processing systems 30 (2017).
- Veličković, Petar, et al. “Graph attention networks.” arXiv preprint arXiv:1710.10903 (2017).
- Gilmer, Justin, et al. “Neural message passing for quantum chemistry.” International conference on machine learning. PMLR, 2017.

As a quick note: since we are writing our code from scratch for understanding, it will end up being a bit more verbose than necessary. In practice, many operations are abstracted away, hidden under-the-hood by GNN libraries like Pytorch Geometric (PyG) [1], allowing us to just focus on the details of our GNN and data which we care about. Our code will not use any data structures or layers from PyG for simplicity, so that only an understanding of Pytorch and Python class-based definitions is necessary to read the code snippets.

The goal by the end of this coding tutorial is to feel comfortable looking at code implementations of PyG-style message-passing. Afterwards, the jump to looking at real source code of different GNNs in PyG (docs) will feel easier, which will enable you to read more GNN papers and make more connections to real code! If you end up using an alternative GNN library other than PyG in your research (or an alternative library to Pytorch, or another programming language altogether!), don’t worry, the general idea of message-passing operations should carry over sufficiently well to other implementations.

To start out, let’s revisit the molecular graph example which we saw in the previous post:

We will build our graph input example from this example, for familiarity. The graph will consist of 6 nodes representing hydrogen (blue) and carbon (green) atoms, each with four features: atomic number, atomic mass, (made-up) charge values, and number of incoming edges. We also have the same adjacency matrix from before, with 18 edges in total connecting nodes together, including self-edges.

We can initialize our node feature matrix and adjacency matrix as Pytorch tensors in our Python code as follows:

```
# Define input node feature matrix and adjacency matrix
input_node_feature_matrix = torch.tensor([
    [1.0, 1.0078, 1, 2],   # atomic number, atomic mass, charge, and number of bonds
    [1.0, 1.0078, 1, 4],
    [6.0, 12.011, -1, 3],
    [1.0, 1.0078, 0, 4],
    [6.0, 12.011, -1, 3],
    [1.0, 1.0078, 1, 2],
], dtype=torch.float32)  # [num_nodes, num_features]
binary_adjacency_matrix = torch.tensor([
    [1, 1, 0, 0, 0, 0],
    [1, 1, 1, 0, 1, 0],
    [0, 1, 1, 1, 0, 0],
    [0, 0, 1, 1, 1, 1],
    [0, 1, 0, 1, 1, 0],
    [0, 0, 0, 1, 0, 1],
], dtype=torch.int64)  # [num_nodes, num_nodes]
```

These two Pytorch tensors contain all of the information we need about node features as well as graph connectivity. You’ll notice, however, that the adjacency matrix contains many zero values, which gets worse as we scale to much larger graphs (millions and billions of nodes!) since nodes tend to be connected to only a few other nodes.

Because of this, we often opt for an edge list representation of graph connectivity information, where instead of a [num_nodes, num_nodes] matrix containing many zeros for missing edges, we transform it into an edge list of shape [2, num_edges], which specifies the index of the start (source) and end (destination) node for each existing edge. With this representation, we only store two pieces of information per edge that actually exists, rather than one piece of information for every possible edge in the graph. Libraries such as PyG opt for this edge list representation, which they call an **edge_index**, so we will define a conversion function ourselves to turn an adjacency matrix into an edge_index tensor as follows:

```
def adj_matrix_to_sparse_edge_index(adj_matr: torch.Tensor):
    """
    This function takes a square binary adjacency matrix, and returns an edge list representation
    containing source and destination node indices for each edge.
    Arguments:
        adj_matr: torch Tensor of adjacency information, shape [num_nodes, num_nodes], dtype torch.int64
    Returns:
        edge_index: torch Tensor of shape [2, num_edges], dtype torch.int64
    """
    src_list = []
    dst_list = []
    for row_idx in range(adj_matr.shape[0]):
        for col_idx in range(adj_matr.shape[1]):
            if adj_matr[row_idx, col_idx].item() > 0:
                src_list.append(row_idx)
                dst_list.append(col_idx)
    return torch.tensor([src_list, dst_list], dtype=torch.int64)  # [2, num_edges]

edge_index = adj_matrix_to_sparse_edge_index(binary_adjacency_matrix)  # [2, num_edges]
```

Now that we have our input node feature and edge_index tensors, we can move on to defining our message-passing layer, which will implement the GraphSAGE message-passing algorithm. If we look at the pseudocode at the beginning, we can see that the main message-passing logic happens in two lines of pseudocode, which are applied for each node in the graph:

- \[h_{N(v)}^{k+1} = \text{AGGREGATE}(\{h_u^k, \forall u \in N(v)\})\]
- \[h_v^{k+1} = \sigma(W \cdot \text{CONCAT}(h_v^k, h_{N(v)}^{k+1}))\]

These two lines of code define mathematically how we will do message passing, specifying the steps of message-passing which we previously covered: assuming (1) messages have been created, we (2) aggregate messages from neighboring nodes to get \(h_{N(v)}^{k+1}\), and (3) update representations, ending up with \(h_v^{k+1}\).

To implement this in code, we will need an organized definition of the message, aggregate, and update steps. In Pytorch-style coding, neural network layers are typically defined in Python class syntax, where we define a Python class which will house our GNN layer:

```
class GraphSAGELayer(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Define linear layers parameterizing neighbors and self-loops
        self.lin_neighbors = nn.Linear(in_dim, out_dim, bias=True)
        self.lin_self = nn.Linear(in_dim, out_dim, bias=True)
```

Here we have defined a Python class called GraphSAGELayer, which inherits from Pytorch’s torch.nn.Module class. Inheriting from Module gives our layer the neural network functionality needed to train the model with stochastic gradient descent, along with all of Pytorch’s other functionality.

Looking again at the two pseudocode lines, we can see that the only learnable parameter in GraphSAGE is a weight matrix \(W\), which parameterizes the concatenation of a node’s own embedding \(h_v^k\) with its neighborhood embedding \(h_{N(v)}^{k+1}\). In practice, GraphSAGE is implemented in Pytorch Geometric (here) using two linear layers: one for the neighboring message embeddings and one for a node’s self embedding. The reason for this might be twofold: (i) having two separate layers leaves the option of not having a self-embedding weight, which can be desirable sometimes, and (ii) sometimes we may want separate weights parameterizing self-connections, which can be seen as a form of skip-connection for GNN embeddings.
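
To see why the two formulations match, here is a quick numerical check (my own sketch; the variable names are just for illustration): multiplying a concatenated vector by one weight matrix is the same as splitting that matrix into two blocks, applying each block separately, and summing the results.

```
import torch

hidden, out = 4, 3
h_self = torch.randn(hidden)            # h_v: a node's own embedding
h_neigh = torch.randn(hidden)           # h_N(v): its aggregated neighborhood embedding

W = torch.randn(out, 2 * hidden)                  # single weight matrix applied to the concatenation
W_self, W_neigh = W[:, :hidden], W[:, hidden:]    # the same matrix, split into two blocks

concat_version = W @ torch.cat([h_self, h_neigh])       # W · CONCAT(h_v, h_N(v))
split_version = W_self @ h_self + W_neigh @ h_neigh     # two "linear layers", summed

print(torch.allclose(concat_version, split_version))    # True
```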

In our code, we will call these two linear layers **lin_neighbors** and **lin_self**, as shown above. Now comes an important part: how do we implement logic to create and pass messages, and aggregate embeddings for neighbors in order to obtain \(h_{N(v)}^{k+1}\)? We can define our own function for message-passing as follows:

```
def message_passing(self, x: torch.Tensor, edge_index: torch.Tensor):
    """
    This function is responsible for passing messages between nodes according to the edges
    in 'edge_index'.
    - Messages from the source --> destination node consist of the source node's feature vector.
    - Sum aggregation is used to aggregate incoming messages from neighbors.
    Arguments:
        x: torch Tensor of node representations, shape [num_nodes, hidden_size], dtype torch.float32
        edge_index: torch Tensor of graph connectivity information, shape [2, num_edges], dtype torch.int64
    Returns:
        neigh_embeds: torch Tensor of aggregated neighbor embeddings, shape [num_nodes, hidden_size], dtype torch.float32
    """
    src_node_indices = edge_index[0, :]  # shape [num_edges]
    dst_node_indices = edge_index[1, :]  # shape [num_edges]
    # Step (1): Message - each source node's message is simply its feature vector
    src_node_feats = x[src_node_indices]  # shape [num_edges, hidden_size]
    # Step (2): Aggregate - sum incoming messages for each destination node
    neighbor_sum_aggregations = []
    for dst_node_idx in range(x.shape[0]):  # loop over destination nodes
        # find incoming edges, get incoming messages from source nodes
        incoming_edge_indices = torch.where(dst_node_indices == dst_node_idx)[0]  # find incoming edges
        incoming_messages = src_node_feats[incoming_edge_indices]  # shape [num_incoming_edges, hidden_size]
        # sum messages from neighbors (handles any number of incoming edges)
        incoming_messages_summed = incoming_messages.sum(dim=0)  # shape [hidden_size]
        neighbor_sum_aggregations.append(incoming_messages_summed)
    neigh_embeds = torch.stack(neighbor_sum_aggregations)  # [num_nodes, hidden_size]
    return neigh_embeds
```

This function is quite involved, so we will go through step-by-step, and point out where code connects to pseudocode with comments. We can see that the inputs to this function are our node feature matrix **x**, and the **edge_index**, which we already have from earlier. The function definition states that we will take these two tensors as input, and we will eventually return \(h_{N(v)}^{k+1}\).

The first thing we need to do is organize how our nodes are going to pass messages to each other, for Step (1): Message. The simplest message which one node can pass to another node is its node embedding (which is also the case in GraphSAGE), so we first get the indices of our source nodes from our **edge_index**, and use that to index into our node feature matrix **x**. If you are familiar with array indexing in Pytorch, you’ll realize that this gives us a [num_edges, hidden_size] tensor, effectively giving us a tensor containing source node embeddings. This is an important step, because with the leading dimension being *num_edges* rather than *num_nodes*, we can do edge operations and deal with passing messages along edges.
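
If this indexing pattern is new to you, here is a tiny standalone example (my own, unrelated to the molecular graph) of how indexing with an integer tensor repeats and reorders rows:

```
import torch

x = torch.tensor([[10., 11.], [20., 21.], [30., 31.]])  # 3 "nodes" with 2 features each
idx = torch.tensor([0, 2, 2, 1])                         # e.g. the source node index of 4 edges
print(x[idx])   # shape [4, 2]: row 0, row 2, row 2 again, then row 1
```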

With this indexing operation, our first step of message creation is already complete, since we are using source node embeddings as the message to be passed. Now, we need to perform the next step, which is to aggregate embeddings for each destination node using a permutation-invariant aggregator. We will use sum aggregation here, since it is a more expressive aggregation function (more on that another time!), which means for each destination node in the graph, we need to sum all incoming message embeddings. We accomplish this by looping over destination nodes, finding which edges are ending at that destination node, and summing the corresponding messages. The resulting variable, **neigh_embeds**, directly corresponds to \(h_{N(v)}^{k+1}\) in the pseudocode.
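
As a side note, the explicit Python loop above is written for clarity; the same sum aggregation can also be done in one vectorized scatter-add (a sketch of my own using index_add_, assuming the same x and edge_index variables as in the function above, not how the layer in this post is written):

```
# Vectorized sum aggregation: add each message into its destination node's row.
src_node_feats = x[edge_index[0]]                          # messages, shape [num_edges, hidden_size]
neigh_embeds = torch.zeros_like(x)                         # one accumulator row per destination node
neigh_embeds.index_add_(0, edge_index[1], src_node_feats)  # sum messages grouped by destination index
```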

We can complete our message-passing layer implementation by writing a forward() function, which tells Pytorch how we want a forward pass through our neural network to be implemented:

```
def forward(self, x: torch.Tensor, edge_index: torch.Tensor):
    """
    Implementation for one message-passing iteration for GraphSAGE.
    Arguments:
        x: torch Tensor of node representations, shape [num_nodes, hidden_size], dtype torch.float32
        edge_index: torch Tensor of graph connectivity information, shape [2, num_edges], dtype torch.int64
    Returns:
        out: torch Tensor of updated node representations, shape [num_nodes, hidden_size], dtype torch.float32
    """
    x_message_passing, x_self = x, x  # duplicate variables pointing to node features
    neigh_embeds = self.message_passing(x_message_passing, edge_index)
    neigh_embeds = self.lin_neighbors(neigh_embeds)  # W_neigh * h_N(v)
    x_self = self.lin_self(x_self)                   # W_self * h_v
    # Step (3): Update - sum the two transformed embeddings
    # (equivalent to CONCAT followed by a single weight matrix W = [W_self | W_neigh])
    out = neigh_embeds + x_self
    return out
```

With the message_passing() function doing the heavy lifting, all we need to do in the forward() function is call it and then perform step (3), updating the node representations. This is done by running \(h_{N(v)}^{k+1}\) and \(h_v^k\) through their respective linear layers and summing the results. Although the pseudocode writes this as a concatenation followed by a single weight matrix \(W\), the two forms are equivalent: multiplying \(\text{CONCAT}(h_v^k, h_{N(v)}^{k+1})\) by \(W\) is the same as splitting \(W\) into a self block and a neighbor block, applying each block separately, and summing, which is exactly what the two linear layers do here. Code implementations use either form, and the two-linear-layer version shown here matches the Pytorch Geometric implementation mentioned earlier.

Now that we have defined a full message-passing class using simple operations, we can complete a full 1-layer GNN model by defining a second class which will use our just-completed message passing layer definition:

```
class GraphSAGEModel(nn.Module):
    def __init__(self, in_features: int, hidden_size: int, out_features: int, dropout: float = 0.1):
        super().__init__()
        self.input_proj = nn.Linear(in_features, hidden_size, bias=True)
        self.conv1 = GraphSAGELayer(in_dim=hidden_size, out_dim=hidden_size)
        self.act1 = nn.ReLU()
        self.drop1 = nn.Dropout(p=dropout)
        self.lin_out = nn.Linear(hidden_size, out_features, bias=True)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor):
        """
        Forward pass implementation of 1-layer GraphSAGE model.
        Arguments:
            x: torch Tensor of input node features, shape [num_nodes, num_features], dtype torch.float32
            edge_index: torch Tensor of graph connectivity information, shape [2, num_edges], dtype torch.int64
        """
        x = self.input_proj(x)         # Input projection: [num_nodes, num_features] --> [num_nodes, hidden_size]
        x = self.conv1(x, edge_index)  # Message-passing
        x = self.act1(x)
        x = self.drop1(x)
        x = self.lin_out(x)
        return F.log_softmax(x, dim=-1)  # log-softmax over the last dim for classification (paired with NLL loss)
```

This class again inherits from nn.Module, and it defines one layer of message-passing by calling the GraphSAGELayer class we just defined above. It also defines several other components, such as a ReLU nonlinearity after the message-passing layer, a dropout layer, and input/output projections. This definition is for a classification model with one message-passing layer; if we wanted to change the task the model is built for, we could swap out the output head and remove the final log-softmax as needed. If we need to pass messages multiple times, we can simply stack more GraphSAGELayer instances (a small sketch follows below). Note that each layer would then have its own weights, i.e. weights are not shared across message-passing iterations, which is common practice.
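
For example, a two-layer variant might look like the following (a sketch that simply follows the same pattern as GraphSAGEModel above; it is not part of the original code):

```
class TwoLayerGraphSAGEModel(nn.Module):
    def __init__(self, in_features: int, hidden_size: int, out_features: int, dropout: float = 0.1):
        super().__init__()
        self.input_proj = nn.Linear(in_features, hidden_size, bias=True)
        self.conv1 = GraphSAGELayer(in_dim=hidden_size, out_dim=hidden_size)  # first message-passing round
        self.conv2 = GraphSAGELayer(in_dim=hidden_size, out_dim=hidden_size)  # second round, separate weights
        self.act = nn.ReLU()
        self.drop = nn.Dropout(p=dropout)
        self.lin_out = nn.Linear(hidden_size, out_features, bias=True)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor):
        x = self.input_proj(x)
        x = self.drop(self.act(self.conv1(x, edge_index)))  # each node sees its 1-hop neighborhood
        x = self.drop(self.act(self.conv2(x, edge_index)))  # each node now sees its 2-hop neighborhood
        x = self.lin_out(x)
        return F.log_softmax(x, dim=-1)
```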

We can now put everything together by defining an instance of our 1-layer GraphSAGE model and doing a full forward pass on our example graph! The full code is below, and is also available on GitHub:

```
import torch
import torch.nn as nn
import torch.nn.functional as F


class GraphSAGELayer(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Define linear layers parameterizing neighbors and self-loops
        self.lin_neighbors = nn.Linear(in_dim, out_dim, bias=True)
        self.lin_self = nn.Linear(in_dim, out_dim, bias=True)

    def message_passing(self, x: torch.Tensor, edge_index: torch.Tensor):
        """
        This function is responsible for passing messages between nodes according to the edges
        in 'edge_index'.
        - Messages from the source --> destination node consist of the source node's feature vector.
        - Sum aggregation is used to aggregate incoming messages from neighbors.
        Arguments:
            x: torch Tensor of node representations, shape [num_nodes, hidden_size], dtype torch.float32
            edge_index: torch Tensor of graph connectivity information, shape [2, num_edges], dtype torch.int64
        Returns:
            neigh_embeds: torch Tensor of aggregated neighbor embeddings, shape [num_nodes, hidden_size], dtype torch.float32
        """
        src_node_indices = edge_index[0, :]  # shape [num_edges]
        dst_node_indices = edge_index[1, :]  # shape [num_edges]
        # Step (1): Message - each source node's message is simply its feature vector
        src_node_feats = x[src_node_indices]  # shape [num_edges, hidden_size]
        # Step (2): Aggregate - sum incoming messages for each destination node
        neighbor_sum_aggregations = []
        for dst_node_idx in range(x.shape[0]):  # loop over destination nodes
            # find incoming edges, get incoming messages from source nodes
            incoming_edge_indices = torch.where(dst_node_indices == dst_node_idx)[0]  # find incoming edges
            incoming_messages = src_node_feats[incoming_edge_indices]  # shape [num_incoming_edges, hidden_size]
            # sum messages from neighbors (handles any number of incoming edges)
            incoming_messages_summed = incoming_messages.sum(dim=0)  # shape [hidden_size]
            neighbor_sum_aggregations.append(incoming_messages_summed)
        neigh_embeds = torch.stack(neighbor_sum_aggregations)  # [num_nodes, hidden_size]
        return neigh_embeds

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor):
        """
        Implementation for one message-passing iteration for GraphSAGE.
        Arguments:
            x: torch Tensor of node representations, shape [num_nodes, hidden_size], dtype torch.float32
            edge_index: torch Tensor of graph connectivity information, shape [2, num_edges], dtype torch.int64
        Returns:
            out: torch Tensor of updated node representations, shape [num_nodes, hidden_size], dtype torch.float32
        """
        x_message_passing, x_self = x, x  # duplicate variables pointing to node features
        neigh_embeds = self.message_passing(x_message_passing, edge_index)
        neigh_embeds = self.lin_neighbors(neigh_embeds)  # W_neigh * h_N(v)
        x_self = self.lin_self(x_self)                   # W_self * h_v
        # Step (3): Update - sum the two transformed embeddings
        # (equivalent to CONCAT followed by a single weight matrix W = [W_self | W_neigh])
        out = neigh_embeds + x_self
        return out


class GraphSAGEModel(nn.Module):
    def __init__(self, in_features: int, hidden_size: int, out_features: int, dropout: float = 0.1):
        super().__init__()
        self.input_proj = nn.Linear(in_features, hidden_size, bias=True)
        self.conv1 = GraphSAGELayer(in_dim=hidden_size, out_dim=hidden_size)
        self.act1 = nn.ReLU()
        self.drop1 = nn.Dropout(p=dropout)
        self.lin_out = nn.Linear(hidden_size, out_features, bias=True)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor):
        """
        Forward pass implementation of 1-layer GraphSAGE model.
        Arguments:
            x: torch Tensor of input node features, shape [num_nodes, num_features], dtype torch.float32
            edge_index: torch Tensor of graph connectivity information, shape [2, num_edges], dtype torch.int64
        """
        x = self.input_proj(x)         # Input projection: [num_nodes, num_features] --> [num_nodes, hidden_size]
        x = self.conv1(x, edge_index)  # Message-passing
        x = self.act1(x)
        x = self.drop1(x)
        x = self.lin_out(x)
        return F.log_softmax(x, dim=-1)  # log-softmax over the last dim for classification


def adj_matrix_to_sparse_edge_index(adj_matr: torch.Tensor):
    """
    This function takes a square binary adjacency matrix, and returns an edge list representation
    containing source and destination node indices for each edge.
    Arguments:
        adj_matr: torch Tensor of adjacency information, shape [num_nodes, num_nodes], dtype torch.int64
    Returns:
        edge_index: torch Tensor of shape [2, num_edges], dtype torch.int64
    """
    src_list = []
    dst_list = []
    for row_idx in range(adj_matr.shape[0]):
        for col_idx in range(adj_matr.shape[1]):
            if adj_matr[row_idx, col_idx].item() > 0:
                src_list.append(row_idx)
                dst_list.append(col_idx)
    return torch.tensor([src_list, dst_list], dtype=torch.int64)  # [2, num_edges]


if __name__ == "__main__":
    # Define input node feature matrix and adjacency matrix
    input_node_feature_matrix = torch.tensor([
        [1.0, 1.0078, 1, 2],   # atomic number, atomic mass, charge, and number of bonds
        [1.0, 1.0078, 1, 4],
        [6.0, 12.011, -1, 3],
        [1.0, 1.0078, 0, 4],
        [6.0, 12.011, -1, 3],
        [1.0, 1.0078, 1, 2],
    ], dtype=torch.float32)
    binary_adjacency_matrix = torch.tensor([
        [1, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 1, 0],
        [0, 1, 1, 1, 0, 0],
        [0, 0, 1, 1, 1, 1],
        [0, 1, 0, 1, 1, 0],
        [0, 0, 0, 1, 0, 1],
    ], dtype=torch.int64)
    edge_index = adj_matrix_to_sparse_edge_index(binary_adjacency_matrix)

    print("input_node_feature_matrix:", input_node_feature_matrix.shape)
    print(input_node_feature_matrix)
    print("binary_adjacency_matrix:", binary_adjacency_matrix.shape)
    print(binary_adjacency_matrix)
    print("edge_index:", edge_index.shape)
    print(edge_index, "\n")

    # Define GraphSAGE model
    model = GraphSAGEModel(
        in_features=4,    # 4 input features per node
        hidden_size=16,   # 16-dimensional latent vectors
        out_features=2,   # 2 classes of nodes in our example: Carbon and Hydrogen
    )
    print("\nModel:")
    print(model, "\n")

    # Forward pass & loss calculation for node classification
    output = model(x=input_node_feature_matrix, edge_index=edge_index)
    atom_labels = torch.tensor([0, 0, 1, 0, 1, 0], dtype=torch.int64)  # 0 = Hydrogen, 1 = Carbon
    loss = F.nll_loss(output, target=atom_labels)
    print("Loss value: {:.5f}".format(loss.item()))
```

I hope this code tutorial was useful for you! Many of these operations are abstracted away under the hood of GNN libraries; however, understanding the underlying operations going on during message-passing is the first step to being able to adapt and improve the algorithm for your own needs and goals. If the code snippets make sense, and you succeed in running them and looking at the printed outputs, I would highly encourage you to look at real source code for GNNs in Pytorch Geometric, for instance the GraphSAGE implementation. You will notice that PyG exposes certain functions, such as **message()**, allowing developers to override them to inject custom behavior during message-passing. It is a clever software engineering design that lets developers build custom GNN models while still abstracting low-level operations, like aggregating neighboring nodes based on edges, away from us.

As always, feedback is welcome and appreciated on this code tutorial at: syed [dot] rizvi [at] yale [dot] edu

- Fey, Matthias, and Jan Eric Lenssen. “Fast graph representation learning with PyTorch Geometric.” arXiv preprint arXiv:1903.02428 (2019).