Notebook 5: Use GraphStorm APIs for a Customized Model to Perform Graph-level Prediction

Graph-level prediction, such as graph classification or graph regression, is a common task in Graph Machine Learning (GML) across various domains, including life sciences and chemistry. In graph-level prediction, the entire graph data is typically organized as a batch of subgraphs, where each subgraph’s nodes have edges only within the subgraph and no edges connecting to nodes in other subgraphs. GML labels are linked to these subgraphs. Once trained, GML models can make predictions on new and unseen subgraphs.

A typical operation used in graph prediction is called Read-out, which aggregates the representations of nodes in a subgraph to form one representation for the subgraph. The outputs of the Read-out can then be used to make predictions downstream, acting as a single representation of the entire subgraph.

[Figure: the Read-out operation aggregating node representations into one subgraph representation]
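As a minimal sketch of a Read-out, assuming mean pooling and a hypothetical `graph_ids` array that maps each node to its subgraph (neither is prescribed by GraphStorm), the node representations of each subgraph can be averaged into a single vector:

```python
import numpy as np

def mean_readout(node_embs, graph_ids, num_graphs):
    """Average node representations per subgraph.

    Illustrative helper only, not a GraphStorm API.
    """
    dim = node_embs.shape[1]
    sums = np.zeros((num_graphs, dim))
    counts = np.zeros(num_graphs)
    # scatter-add every node's representation into the row of its subgraph
    np.add.at(sums, graph_ids, node_embs)
    np.add.at(counts, graph_ids, 1)
    return sums / counts[:, None]

# five nodes with 2-dimensional representations, split over two subgraphs
node_embs = np.array([[1., 2.], [3., 4.], [5., 6.], [0., 0.], [2., 2.]])
graph_ids = np.array([0, 0, 0, 1, 1])
graph_embs = mean_readout(node_embs, graph_ids, num_graphs=2)
print(graph_embs)  # one row per subgraph
```

Other reductions, such as sum or max, could replace the mean here; which one works best is task dependent.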

The current version of GraphStorm cannot directly perform graph-level prediction. However, because GraphStorm supports node-level prediction, we can use a method called super-node to perform graph-level prediction.


Super-node Method Explanation

Instead of using the Read-out operation, we add a new node, called a super node, to each subgraph, and link all original nodes of the subgraph to it, without adding reverse edges. Through these inbound edges, the representations of all original nodes in a subgraph are aggregated into the super node during message passing. We can then use the super node as the representation of the subgraph to perform graph-level prediction tasks. In this way, the super-node method turns a graph prediction task into a node prediction task.

[Figure: the super-node method, with all original nodes of a subgraph linked to a super node]
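The idea can be illustrated with a small plain-Python sketch. The helper names `add_super_node` and `aggregate_inbound` are hypothetical, and mean aggregation is only one possible choice; a real GNN layer would also apply trainable transformations.

```python
def add_super_node(num_nodes):
    """Return (super_id, edges) linking every original node to one new super node."""
    super_id = num_nodes  # the new node takes the next free ID
    # directed edges original -> super only; no reverse edges are added
    edges = [(nid, super_id) for nid in range(num_nodes)]
    return super_id, edges

def aggregate_inbound(feats, edges, target):
    """Mean of the features of all source nodes that have an edge into `target`."""
    srcs = [s for s, d in edges if d == target]
    dim = len(feats[srcs[0]])
    return [sum(feats[s][k] for s in srcs) / len(srcs) for k in range(dim)]

# a subgraph with three original nodes and 2-dimensional features
feats = {0: [1.0, 0.0], 1: [3.0, 2.0], 2: [2.0, 4.0]}
super_id, edges = add_super_node(num_nodes=3)
super_repr = aggregate_inbound(feats, edges, super_id)
print(super_id, edges, super_repr)
```

Because the edges point only toward the super node, the original nodes never receive messages from it, so message passing inside the subgraph is unchanged.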


Implementation Ideas

In order to use the super-node method in GraphStorm, we need to implement two extra functions.

  • Raw Graph Data Conversion: Add a super node to each subgraph in the original batch of subgraphs, and store all of them as one heterogeneous graph ready for GraphStorm’s graph construction CLIs. After the graph is converted into GraphStorm’s distributed graph format, we can use all of GraphStorm’s built-in GNN models to perform the super-node prediction.

  • Customized GNN Encoder (Optional): Create a specialized GNN encoder to aggregate super node representations. This is an optional function as all built-in GraphStorm GNN encoders can aggregate and generate embeddings for super nodes naturally. But creating a customized GNN encoder can provide fine-grained control of aggregation methods, which can mimic the Read-out method.
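The Raw Graph Data Conversion step can be sketched roughly as follows. This is a simplified illustration, not the conversion script used later in this notebook: the function name and the edge type names `to_node` and `to_super` are hypothetical, and only the node type names `node` and `super` match the ones used in the dataloaders below.

```python
def build_supernode_graph(subgraphs):
    """Merge subgraphs into one heterogeneous edge-list graph with one super node each.

    `subgraphs` is a list of (num_nodes, edges) pairs that use local node IDs.
    The edge type names in the returned dict are hypothetical placeholders.
    """
    node_edges, super_edges = [], []
    offset = 0
    for super_id, (num_nodes, edges) in enumerate(subgraphs):
        # shift local IDs so that each subgraph owns a disjoint global ID range
        node_edges += [(s + offset, d + offset) for s, d in edges]
        # link every original node to the super node of its subgraph
        super_edges += [(nid + offset, super_id) for nid in range(num_nodes)]
        offset += num_nodes
    graph = {
        ('node', 'to_node', 'node'): node_edges,
        ('node', 'to_super', 'super'): super_edges,
    }
    return graph, offset, len(subgraphs)

# two subgraphs: a 3-node path and a 2-node pair
subgraphs = [(3, [(0, 1), (1, 2)]), (2, [(0, 1)])]
graph, num_nodes, num_supers = build_supernode_graph(subgraphs)
for etype, edge_list in graph.items():
    print(etype, edge_list)
```

In such a layout the graph-level labels would be attached to the `super` nodes, which is what allows a node prediction pipeline to be reused.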

This notebook will demonstrate the super-node method by using GraphStorm APIs and other libraries to implement both functions. This notebook serves as an example of a Graph Classification Solution using GraphStorm APIs. Users can modify the custom GNN model and implement their own version.


Prerequisites

This notebook assumes the following:


1. Raw Graph Data Conversion

The diagram below illustrates how the raw graph prediction dataset is converted into the super-node format for GraphStorm.

[Figure: conversion of the raw graph prediction dataset into the super-node format]

In this notebook, we use the OGBG Molhiv data, which is a popular graph-level molecular property prediction dataset. In the interest of space, we will not show the actual raw graph data conversion code in this notebook. Users can find the source code of the OGBG data conversion in GraphStorm’s graph prediction example folder.

Tip: We also provide a Python script to generate synthetic supernode-based graph data for users to better understand the super-node graph data format, which is available here.

We can download the source code of the OGBG conversion and generate the super-node format OGBG data with the following commands.

[1]:
!wget -c https://raw.githubusercontent.com/awslabs/graphstorm/main/examples/graph_prediction/gen_ogbg_supernode.py

!python gen_ogbg_supernode.py --ogbg-data-name molhiv --output-path ./supernode_raw/

The converted OGBG data will be stored at ./supernode_raw/. Then we can run GraphStorm’s GConstruct command to partition the graph for model training and inference. The processed graph is stored in the ./supernode_gs_1p/ folder.

[2]:
!python -m graphstorm.gconstruct.construct_graph \
        --conf-file ./supernode_raw/config.json \
        --output-dir ./supernode_gs_1p/ \
        --num-parts 1 \
        --graph-name supernode_molhiv
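Before moving on, it can be useful to confirm that graph construction produced the partition configuration file that the data loading step later in this notebook reads. The sketch below assumes the file is named after the graph name inside the output folder, which is consistent with the path used when the data is loaded; the helper `partition_config_path` is hypothetical.

```python
import os

def partition_config_path(output_dir, graph_name):
    """Build the expected path of the partition config JSON (hypothetical helper)."""
    return os.path.join(output_dir, f'{graph_name}.json')

part_config = partition_config_path('./supernode_gs_1p/', 'supernode_molhiv')
if os.path.isfile(part_config):
    print('Found partition config:', part_config)
else:
    print('Missing partition config, consider rerunning graph construction:', part_config)
```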

2. Customized GNN Encoder for Graph Prediction

The key component of this super-node based solution is the GNN model, which performs message passing and aggregation in each subgraph and then performs a form of Read-out operation on the super nodes. This component can be implemented as a customized GraphStorm GNN encoder, as demonstrated in Notebook 4: Customized Models.

As shown in the diagram below, a super node aggregates the representations from other nodes in each GNN layer. Built-in GNN encoders will update the aggregated representations with additional trainable parameters. This operation differs from the common Read-out operation and may therefore lead to worse model performance in graph-level prediction.

[Figure: a super node aggregating the representations of other nodes in each GNN layer]

To mimic the Read-out operation, we can cache the aggregated representations, and reset the super node’s own representation to zeros after each GNN layer computation. Using this method, we can still leverage the built-in GraphStorm encoders, e.g., RelationalGCNEncoder and RelationalGATEncoder, but avoid the built-in self-update operation from one layer to another. In addition, we can design a more flexible Read-out function on these cached representations, rather than using only the last layer’s aggregated representations.

The GPEncoder4SupernodeOgbg class below implements this caching mechanism, and provides a few options for the read-out function.

[3]:
import torch as th
from graphstorm.model import (GSgnnNodeModel,
                              GSNodeEncoderInputLayer,
                              RelationalGCNEncoder,
                              RelationalGATEncoder,
                              HGTEncoder,
                              EntityClassifier,
                              ClassifyLossFunc)
from graphstorm.model.gnn_encoder_base import GraphConvEncoder

class GPEncoder4SupernodeOgbg(GraphConvEncoder):
    r"""A graph conv encoder for Graph Classification

        Unique arguments in this class:
        -------------------------------
        base_encoder: GraphStorm ConvEncoder types, options:
            1. `RelationalGCNEncoder`;
            2. `RelationalGATEncoder`;
            3. `HGTEncoder`.
        read_out_opt: string
            The aggregation method for the cached supernode representations. The current options:
            1. `last_only`: only use the last layer's representations. With this option no
               aggregation across layers is applied, because only one layer's representation
               is involved in the final read-out.
            2. `mean`: compute the mean of all of the cached supernode representations.
            3. `sum`: compute the summation of all of the cached supernode representations.
            4. `weighted_sum`: use additional weight parameters to compute the weighted summation
               of all of the cached supernode representations.
            5. `min`: compute the minimum in each dimension of all of the cached supernode
               representations.
            6. `max`: compute the maximum in each dimension of all of the cached supernode
               representations.
        super_ntype: string
            The name of supernode type. Default is 'super'.

    """
    def __init__(self,
                 h_dim,
                 out_dim,
                 base_encoder,
                 read_out_opt='last_only',
                 super_ntype='super'
                ):
        assert isinstance(base_encoder, (RelationalGCNEncoder, RelationalGATEncoder, HGTEncoder)), \
               'Only support GraphStorm\'s RelationalGCNEncoder, RelationalGATEncoder, and HGTEncoder'
        assert base_encoder.num_layers >= 3, 'For Graph Prediction task, at least two layers GNN ' + \
                                       f'encoder required, but got {base_encoder.num_layers - 1} ...'
        super(GPEncoder4SupernodeOgbg, self).__init__(h_dim, out_dim, base_encoder.num_layers)

        # the assertion message must be a plain string expression (no leading unary `+`)
        assert read_out_opt in ['last_only', 'mean', 'sum', 'weighted_sum', 'min', 'max'], \
                                f'Not recognized read_out_opt {read_out_opt}. Options include ' + \
                                '\'last_only\', \'mean\', \'sum\', \'weighted_sum\', \'min\', ' + \
                                'and \'max\'.'
        self.base_encoder = base_encoder
        self.read_out_opt = read_out_opt
        self.super_ntype = super_ntype
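        if read_out_opt=='weighted_sum':
            # one trainable weight per cached layer output; initialized to ones so that
            # training starts from a plain sum instead of uninitialized memory
            self.weighted_sum_para = th.nn.Parameter(th.ones(1, base_encoder.num_layers))
        else:
            self.weighted_sum_para = None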

    def forward(self, blocks, h):

        supernode_cache = []

        # message passing in subgraphs and cache super-nodes representations
        for layer, block in zip(self.base_encoder.layers, blocks):
            h = layer(block, h)

            # 1. cache the output of supernodes in each layer
            supernode_cache.append(h[self.super_ntype])
            # 2. zero out the representations of supernodes as the next layer input
            h[self.super_ntype] = th.zeros_like(h[self.super_ntype])

        # add final read_out functions.
        supernode_cache = th.stack(supernode_cache)
        output = self._read_out_ops(supernode_cache)

        return {self.super_ntype: output}

    def _read_out_ops(self, supernode_cache):
        """ The supernode_cache shape L * N * D
            The output shape N * D
        """
        if self.read_out_opt=='last_only':
            output = supernode_cache[-1]
        elif self.read_out_opt=='mean':
            output = th.mean(supernode_cache, dim=0)
        elif self.read_out_opt=='sum':
            output = th.sum(supernode_cache, dim=0)
        elif self.read_out_opt=='weighted_sum' and self.weighted_sum_para is not None:
            output = th.einsum('ij, jkl->kl', self.weighted_sum_para, supernode_cache)
        elif self.read_out_opt=='min':
            # th.min over a dimension returns (values, indices); keep the values only
            output = th.min(supernode_cache, dim=0).values
        elif self.read_out_opt=='max':
            output = th.max(supernode_cache, dim=0).values
        else:
            raise NotImplementedError('Only support last_only, mean, sum, weighted_sum, min, ' + \
                                      f'and max read_out_opt, but got {self.read_out_opt}.')

        return output
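To make the tensor shapes of the read-out options listed in the class docstring concrete, here is a small NumPy stand-in that reduces a toy cache of shape L * N * D to N * D. The values and the weight vector are made up for illustration, and NumPy is used only so that the snippet runs without a GPU stack; the encoder itself operates on torch tensors.

```python
import numpy as np

# toy cache: L=3 layers, N=2 super nodes, D=2 dimensions
cache = np.array([[[1., 2.], [3., 4.]],
                  [[5., 6.], [7., 8.]],
                  [[9., 1.], [2., 3.]]])
# shape 1 x L; stands in for the trainable weighted-sum parameter
weights = np.array([[0.5, 0.25, 0.25]])

read_outs = {
    'last_only': cache[-1],
    'mean': cache.mean(axis=0),
    'sum': cache.sum(axis=0),
    # same index pattern as the einsum in the encoder: sum over layers j
    'weighted_sum': np.einsum('ij,jkl->kl', weights, cache),
    'min': cache.min(axis=0),
    'max': cache.max(axis=0),
}
for name, out in read_outs.items():
    print(name, out.shape, out.tolist())
```

Every option collapses the layer dimension, so the decoder always receives one D-dimensional vector per super node regardless of the option chosen.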

3. Training and Inference Pipeline

With the customized encoder modified for graph prediction, we can reuse GraphStorm’s end-to-end training and inference pipeline, in the same way as in Notebook 1: Node Classification Pipeline and Notebook 2: Link Prediction Pipeline, to conduct the graph classification task on the converted super-node OGBG data.

[4]:
import logging
logging.basicConfig(level=logging.INFO)

import graphstorm as gs
gs.initialize()
[5]:
ogbg_data = gs.dataloading.GSgnnData(part_config='./supernode_gs_1p/supernode_molhiv.json')
[6]:
# define dataloaders for training, validation, and testing
nfeats_4_modeling = {'node': ['n_feat'], 'super': ['n_feat']}

train_dataloader = gs.dataloading.GSgnnNodeDataLoader(
    dataset=ogbg_data,
    target_idx=ogbg_data.get_node_train_set(ntypes=['super']),
    node_feats=nfeats_4_modeling,
    label_field='labels',
    fanout=[20, 20, 20],
    batch_size=128,
    train_task=True)
val_dataloader = gs.dataloading.GSgnnNodeDataLoader(
    dataset=ogbg_data,
    target_idx=ogbg_data.get_node_val_set(ntypes=['super']),
    node_feats=nfeats_4_modeling,
    label_field='labels',
    fanout=[100, 100, 100],
    batch_size=256,
    train_task=False)
test_dataloader = gs.dataloading.GSgnnNodeDataLoader(
    dataset=ogbg_data,
    target_idx=ogbg_data.get_node_test_set(ntypes=['super']),
    node_feats=nfeats_4_modeling,
    label_field='labels',
    fanout=[100, 100, 100],
    batch_size=256,
    train_task=False)

For the GNN model, we can create a GraphStorm GNN model using nearly the same architecture as in the other notebooks, except that we replace the built-in GNN encoder, e.g., RelationalGCNEncoder, with the customized GPEncoder4SupernodeOgbg, which wraps a RelationalGCNEncoder as its base encoder.

[7]:
class RgcnGCModel4SuperOgbg(GSgnnNodeModel):
    """ A customized GNN model for graph classification using GraphStorm APIs

    Arguments
    ----------
    g: DistGraph
        A DGL DistGraph.
    num_hid_layers: int
        The number of gnn layers.
    node_feat_field: dict of list of strings
        The list of features for each node type to be used in the model.
    hid_size: int
        The dimension of hidden layers.
    num_classes: int
        The target number of classes for classification.
    multilabel: boolean
        Indicator of if this is a multilabel task.
    """
    def __init__(self,
                 g,
                 num_hid_layers,
                 node_feat_field,
                 hid_size,
                 num_classes,
                 multilabel=False):
        super().__init__(alpha_l2norm=0.)

        # extract feature size
        feat_size = gs.get_node_feat_size(g, node_feat_field)

        # set an input layer encoder
        encoder = GSNodeEncoderInputLayer(g=g, feat_size=feat_size, embed_size=hid_size)
        self.set_node_input_encoder(encoder)

        # set an RGCN encoder as the base encoder
        gnn_encoder = RelationalGCNEncoder(g=g,
                                           h_dim=hid_size,
                                           out_dim=hid_size,
                                           num_hidden_layers=num_hid_layers)
        # wrap the base RGCN encoder into GPEncoder4SupernodeOgbg
        gp_encoder = GPEncoder4SupernodeOgbg(hid_size,
                                             hid_size,
                                             gnn_encoder,
                                             read_out_opt='last_only',
                                             super_ntype='super')
        self.set_gnn_encoder(gp_encoder)

        # set a decoder specific to node classification task
        decoder = EntityClassifier(in_dim=hid_size,
                                   num_classes=num_classes,
                                   multilabel=multilabel)
        self.set_decoder(decoder)

        # classification loss function
        self.set_loss_func(ClassifyLossFunc(multilabel=multilabel))

        # initialize model's optimizer
        self.init_optimizer(lr=0.001,
                            sparse_optimizer_lr=0.001,
                            weight_decay=0)

[8]:
model = RgcnGCModel4SuperOgbg(g=ogbg_data.g,
                              num_hid_layers=3,
                              node_feat_field=nfeats_4_modeling,
                              hid_size=128,
                              num_classes=2)
[9]:
# setup a classification evaluator for the trainer
evaluator = gs.eval.GSgnnClassificationEvaluator(eval_frequency=100,
                                                 eval_metric_list=['roc_auc'])
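The evaluator above tracks the `roc_auc` metric. As a rough illustration of what this metric measures for a binary task, the plain-Python sketch below computes ROC-AUC as the fraction of (positive, negative) pairs that the scores rank correctly. It is not GraphStorm's implementation, and the labels and scores are made up.

```python
def roc_auc(labels, scores):
    """ROC-AUC as the share of (positive, negative) pairs ranked correctly.

    Ties between a positive and a negative score count as half a correct pair.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError('ROC-AUC needs at least one positive and one negative label.')
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
auc = roc_auc(labels, scores)
print(auc)
```

Because the metric depends only on the ranking of the scores, it does not require choosing a classification threshold.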
[10]:
# create a GraphStorm node task trainer for the RGCN model
trainer = gs.trainer.GSgnnNodePredictionTrainer(model)
trainer.setup_evaluator(evaluator)
trainer.setup_device(gs.utils.get_device())
[11]:
# Train the model with the trainer using the fit() function
trainer.fit(train_loader=train_dataloader,
            val_loader=val_dataloader,
            test_loader=test_dataloader,
            num_epochs=50,
            save_model_path='a_save_path/')

Next, we examine the model performance over the training process.

[12]:
# Extract the ROC-AUC scores from the trainer's evaluator history:
val_accs, test_accs = [], []
for val_acc, test_acc in trainer.evaluator.history:
    val_accs.append(val_acc['roc_auc'])
    test_accs.append(test_acc['roc_auc'])
[13]:
import matplotlib.pyplot as plt

# plot the learning curves
fig, ax = plt.subplots(figsize=(15, 10))
ax.plot(val_accs, label='val')
ax.plot(test_accs, label='test')
ax.set(xlabel='Eval Times', ylabel='ROC_AUC')
ax.legend(loc='best');
[14]:
# after training, the best model is saved to disk:
best_model_path = trainer.get_best_model_path()
print('Best model checkpoint:', best_model_path)
[15]:
# check the saved artifacts
!ls -ls {best_model_path}
[16]:
# we can restore the model from the checkpoint:
model.restore_model(best_model_path)
[17]:
# Setup dataloader for inference
infer_dataloader = gs.dataloading.GSgnnNodeDataLoader(dataset=ogbg_data,
                                                      target_idx=ogbg_data.get_node_test_set(ntypes=['super']),
                                                      node_feats=nfeats_4_modeling,
                                                      label_field='labels',
                                                      fanout=[100, 100, 100],
                                                      batch_size=256,
                                                      train_task=False)
[18]:
# Create an Inferrer object
infer = gs.inference.GSgnnNodePredictionInferrer(model)

# Run inference on the inference dataset
infer.infer(infer_dataloader,
            save_embed_path='infer/embeddings',
            save_prediction_path='infer/predictions',
            use_mini_batch_infer=True)
[19]:
# The GNN embeddings on the inference graph are saved to:
!ls -lh infer/embeddings
!ls -lh infer/embeddings/super/
[20]:
!ls -lh infer/predictions
!ls -lh infer/predictions/super