Notebook 3: Use GraphStorm APIs for Implementing Built-in GNN Models

This notebook demonstrates how to use GraphStorm APIs to implement GraphStorm built-in GNN models such as RGAT and HGT, for different tasks.

In this notebook, we use different GSgnnEncoder modules and set the corresponding arguments in a GNN model to reproduce several GraphStorm built-in GNN models, such as RGAT and HGT. Using the same pipelines demonstrated in Notebook 1: Node Classification Pipeline and Notebook 2: Link Prediction Pipeline, users can easily conduct node classification and link prediction tasks on the ACM dataset created by Notebook 0: Data Preparation.

Prerequisites


1 Revisit an RGCN model in demo_models.py

Notebook 1 and Notebook 2 both use RGCN models that share the same GNN model architecture defined by GraphStorm. To modify a GraphStorm GNN model, let’s first revisit an RGCN model in the demo_models.py file. For simplicity, some docstrings are removed and the code is restructured to fit in notebook cells.

[15]:
import graphstorm as gs
from graphstorm.model import GSgnnNodeModel, GSNodeEncoderInputLayer, RelationalGCNEncoder, EntityClassifier, ClassifyLossFunc

class RgcnNCModel(GSgnnNodeModel):
    """ A simple RGCN model for node classification using GraphStorm APIs
    """
    def __init__(self, g, num_hid_layers, node_feat_field, hid_size, num_classes, multilabel=False):
        super(RgcnNCModel, self).__init__(alpha_l2norm=0.)

        # extract node feature dimensions
        feat_size = gs.get_node_feat_size(g, node_feat_field)

        # set an input encoder
        encoder = GSNodeEncoderInputLayer(g=g, feat_size=feat_size, embed_size=hid_size)
        self.set_node_input_encoder(encoder)

        # set a GNN encoder
        gnn_encoder = RelationalGCNEncoder(g=g, h_dim=hid_size, out_dim=hid_size, num_hidden_layers=num_hid_layers-1)
        self.set_gnn_encoder(gnn_encoder)

        # set a decoder specific to node classification task
        decoder = EntityClassifier(in_dim=hid_size, num_classes=num_classes, multilabel=multilabel)
        self.set_decoder(decoder)

        # classification loss function
        self.set_loss_func(ClassifyLossFunc(multilabel=multilabel))

        # initialize model's optimizer
        self.init_optimizer(lr=0.001, sparse_optimizer_lr=0.01, weight_decay=0)

1.1 GraphStorm built-in model architecture

A GraphStorm built-in model normally contains four modules:

  • An input encoder that converts input node features to the embeddings with hidden dimensions.

  • A GNN encoder that takes the embeddings from the input layer and performs message passing computation.

  • A decoder that is task specific, e.g., the EntityClassifier for classification tasks.

  • A loss function that matches specific tasks, e.g., the ClassifyLossFunc.

Besides the four modules, a GraphStorm GNN model also needs to initialize its own optimizer object.
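
To make the data flow between the four modules concrete, here is a minimal sketch in plain Python. Every function below is a toy stand-in invented for illustration, not a GraphStorm class, and the tiny graph, features, and labels are made up. It only shows the order in which the modules are applied: input encoder, GNN encoder, decoder, then loss.

```python
# Toy stand-ins written in plain Python; none of these are GraphStorm APIs.

def input_encoder(raw_feats, hid_size):
    """Map each node's raw features to hid_size values (here: pad or truncate)."""
    return {n: (f + [0.0] * hid_size)[:hid_size] for n, f in raw_feats.items()}

def gnn_encoder(embeds, edges):
    """One round of message passing: average a node with its in-neighbors."""
    out = {}
    for node, emb in embeds.items():
        msgs = [embeds[src] for src, dst in edges if dst == node] + [emb]
        out[node] = [sum(vals) / len(msgs) for vals in zip(*msgs)]
    return out

def decoder(embeds):
    """Task-specific head: predict the index of the largest embedding entry."""
    return {n: max(range(len(e)), key=e.__getitem__) for n, e in embeds.items()}

def loss_func(preds, labels):
    """Task-specific loss: fraction of misclassified labeled nodes (0-1 loss)."""
    return sum(preds[n] != y for n, y in labels.items()) / len(labels)

# Hypothetical three-node graph with two edges pointing at node "p0".
raw_feats = {"p0": [1.0, 0.0], "p1": [0.0, 1.0], "p2": [0.0, 3.0]}
edges = [("p1", "p0"), ("p2", "p0")]   # (source, destination) pairs
labels = {"p0": 1, "p1": 1, "p2": 1}

hidden = input_encoder(raw_feats, hid_size=2)   # module 1
hidden = gnn_encoder(hidden, edges)             # module 2
preds = decoder(hidden)                         # module 3
loss = loss_func(preds, labels)                 # module 4
```

In this toy setup, node "p0" is classified correctly only because message passing mixes in its neighbors' features, which is the role the GNN encoder plays between the input encoder and the decoder.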

1.2 Model arguments

Each GNN model may have its own arguments. Some arguments are common across models, such as the input and output dimensions, while others are model specific. For example, an RGCN model takes the number of bases, which reduces the number of learnable parameters, and attention-based models may need the number of attention heads. GML tasks need their own arguments too. For example, a classification task may be multi-label.
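
As a rough illustration of why the number of bases matters, the sketch below counts relation weights in one RGCN layer under the standard basis decomposition, where each relation's weight matrix is a learned mix of a few shared basis matrices. The function name and the sizes (6 relation types, 128 dimensions) are hypothetical, and biases and self-loop weights are ignored.

```python
def rgcn_relation_weight_count(num_rels, in_dim, out_dim, num_bases=None):
    """Count relation weight parameters in one RGCN layer (biases ignored)."""
    if num_bases is None:
        # No decomposition: one full in_dim x out_dim matrix per relation.
        return num_rels * in_dim * out_dim
    # Shared basis matrices plus one mixing coefficient per (relation, basis).
    return num_bases * in_dim * out_dim + num_rels * num_bases

# Hypothetical sizes, chosen only for illustration.
full = rgcn_relation_weight_count(num_rels=6, in_dim=128, out_dim=128)
shared = rgcn_relation_weight_count(num_rels=6, in_dim=128, out_dim=128, num_bases=2)
print(full, shared)   # 98304 32780
```

With these sizes, two bases cut the relation weights to about a third, and the saving grows with the number of relation types.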

GraphStorm APIs provide default values for many arguments. For more flexibility, we can expose some arguments in model initialization, such as num_hid_layers and hid_size.

1.3 GML task modules

Besides model-related modules, a GNN model also contains task-specific modules, including task-specific decoders and loss functions. For example, to perform a node classification task, the above RgcnNCModel chooses the EntityClassifier as its decoder and uses the ClassifyLossFunc as its loss function.


2 Reproduce GraphStorm Built-in GNN Model Variants

Knowing the common architecture and arguments, it is easy to reproduce GraphStorm built-in GNN model variants.

2.1 Reproduce an RGAT Model for Node Classification

To turn the demo RgcnNCModel code into an RgatNCModel, only two modifications are needed:

  1. For the GNN encoder, replace the RelationalGCNEncoder with the RelationalGATEncoder.

  2. Add the RelationalGATEncoder-specific arguments to the initializer.

Below is the simplified code of the RgatNCModel model. The complete code can be found in the demo_models.py file.

[2]:
from graphstorm.model import RelationalGATEncoder

class RgatNCModel(GSgnnNodeModel):
    """ A simple RGAT model for node classification using GraphStorm APIs
    """
    def __init__(self, g,
                 num_heads,    # an argument specific to RelationalGATEncoder
                 num_hid_layers, node_feat_field, hid_size, num_classes, multilabel=False):
        super(RgatNCModel, self).__init__(alpha_l2norm=0.)

        # input encoder remains the same ......

        # set a GNN encoder
        gnn_encoder = RelationalGATEncoder(g=g, h_dim=hid_size, out_dim=hid_size,
                                           num_heads=num_heads,    # pass the num_heads to the RelationalGATEncoder
                                           num_hidden_layers=num_hid_layers-1)
        self.set_gnn_encoder(gnn_encoder)

        # decoder, loss function, and optimizer initialization remain the same ......
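
When choosing num_heads, it can help to check it against hid_size before building the model. The sketch below assumes the common multi-head convention in which the output dimension is split evenly across heads; whether RelationalGATEncoder enforces exactly this is an assumption here, and per_head_size is a hypothetical helper, not a GraphStorm function.

```python
def per_head_size(hid_size, num_heads):
    """Return the per-head dimension, assuming hid_size is split evenly across heads."""
    if num_heads <= 0:
        raise ValueError("num_heads must be a positive integer")
    if hid_size % num_heads != 0:
        raise ValueError(
            f"hid_size={hid_size} is not divisible by num_heads={num_heads}")
    return hid_size // num_heads

# Illustrative values: a 128-dimensional hidden size shared by 8 heads.
head_size = per_head_size(hid_size=128, num_heads=8)
print(head_size)   # 16
```

Running such a check up front turns a shape mismatch deep inside the encoder into an immediate, readable error.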

3.1 Training pipeline

[16]:
import logging
import graphstorm as gs

logging.basicConfig(level=logging.INFO)
gs.initialize()

nfeats_4_modeling = {'author':['feat'], 'paper':['feat'],'subject':['feat']}

acm_data = gs.dataloading.GSgnnData(part_config='./acm_gs_1p/acm.json', node_feat_field=nfeats_4_modeling)

train_dataloader = gs.dataloading.GSgnnLinkPredictionDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_edge_train_set(etypes=[('paper', 'citing', 'paper')]),
    fanout=[20, 20],
    num_negative_edges=10,
    node_feats=nfeats_4_modeling,
    batch_size=64,
    exclude_training_targets=True,
    reverse_edge_types_map={("paper", "citing", "paper"):("paper","cited","paper")},
    train_task=True)
val_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_edge_val_set(etypes=[('paper', 'citing', 'paper')]),
    fanout=[100, 100],
    num_negative_edges=100,
    node_feats=nfeats_4_modeling,
    batch_size=256)
test_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_edge_test_set(etypes=[('paper', 'citing', 'paper')]),
    fanout=[100, 100],
    num_negative_edges=100,
    node_feats=nfeats_4_modeling,
    batch_size=256)

from demo_models import HgtLPModel    # Import the HGT model variant

model = HgtLPModel(g=acm_data.g,
                   num_heads=8,
                   num_hid_layers=2,
                   node_feat_field=nfeats_4_modeling,
                   hid_size=128)

evaluator = gs.eval.GSgnnMrrLPEvaluator(eval_frequency=1000)

trainer = gs.trainer.GSgnnLinkPredictionTrainer(model, topk_model_to_save=1)
trainer.setup_evaluator(evaluator)
trainer.setup_device(gs.utils.get_device())

trainer.fit(train_loader=train_dataloader,
            val_loader=val_dataloader,
            test_loader=test_dataloader,
            num_epochs=50,
            save_model_path='a_save_path/',
            save_model_frequency=1000,
            use_mini_batch_infer=True)
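
The evaluator used above reports mean reciprocal rank (MRR): each positive edge is ranked against its sampled negative edges by score, and the reciprocals of those ranks are averaged. The following is a minimal plain-Python sketch of that metric, not GraphStorm's implementation; the function name, the example scores, and the pessimistic tie handling are assumptions made for illustration.

```python
def mean_reciprocal_rank(pos_scores, neg_scores):
    """Average 1/rank of each positive edge among its own negative edges.

    pos_scores[i] is the score of positive edge i, and neg_scores[i] is the
    list of scores of the negatives sampled for that edge.
    """
    total = 0.0
    for pos, negs in zip(pos_scores, neg_scores):
        # Rank 1 means the positive outscored every negative.
        # Ties count against the positive (an assumption of this sketch).
        rank = 1 + sum(neg >= pos for neg in negs)
        total += 1.0 / rank
    return total / len(pos_scores)

# Made-up scores: the first positive ranks 1st, the second ranks 2nd.
mrr = mean_reciprocal_rank(
    pos_scores=[0.9, 0.2],
    neg_scores=[[0.1, 0.3], [0.5, 0.1]],
)
print(mrr)   # 0.75
```

With num_negative_edges=100 in the validation and test dataloaders, each positive edge would be ranked among 101 candidates under this definition, so the lowest possible reciprocal rank per edge is 1/101.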

3.2 Visualize Model Performance History

[18]:
import matplotlib.pyplot as plt

val_metrics, test_metrics = [], []
for val_metric, test_metric in trainer.evaluator.history:
    val_metrics.append(val_metric['mrr'])
    test_metrics.append(test_metric['mrr'])

fig, ax = plt.subplots()
ax.plot(val_metrics, label='val')
ax.plot(test_metrics, label='test')
ax.set(xlabel='Epoch', ylabel='MRR')
ax.legend(loc='best')
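
Besides plotting, the same history can be scanned for the evaluation with the best validation score. The sketch below assumes the history is a list of (validation, test) dictionary pairs keyed by metric name, as the loop above suggests. The best_eval helper is hypothetical and the values are made up.

```python
def best_eval(history, metric="mrr"):
    """Return (index, val_score, test_score) of the best validation entry."""
    idx = max(range(len(history)), key=lambda i: history[i][0][metric])
    val, test = history[idx]
    return idx, val[metric], test[metric]

# Made-up history with the same shape as the pairs iterated above.
toy_history = [
    ({"mrr": 0.41}, {"mrr": 0.39}),
    ({"mrr": 0.58}, {"mrr": 0.55}),
    ({"mrr": 0.56}, {"mrr": 0.57}),
]
best_idx, best_val, best_test = best_eval(toy_history)
print(best_idx, best_val, best_test)   # 1 0.58 0.55
```

Selecting by the validation score and then reading off the paired test score avoids choosing a checkpoint based on test performance.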

3.3 Inference pipeline

[19]:
best_model_path = trainer.get_best_model_path()
print('Best model path:', best_model_path)

model.restore_model(best_model_path)

infer_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_edge_infer_set(etypes=[('paper', 'citing', 'paper')]),
    fanout=[100, 100],
    num_negative_edges=100,
    node_feats=nfeats_4_modeling,
    batch_size=256)

infer = gs.inference.GSgnnLinkPredictionInferrer(model)

infer.infer(acm_data,
            infer_dataloader,
            save_embed_path='infer/embeddings',
            use_mini_batch_infer=True)