Notebook 3: Use GraphStorm APIs for Implementing Built-in GNN Models
This notebook demonstrates how to use GraphStorm APIs to implement GraphStorm built-in GNN models such as RGAT and HGT, for different tasks.
In this notebook, we use different GSgnnEncoder modules and set the corresponding arguments in a GNN model, thereby reproducing several GraphStorm built-in GNN models, such as RGAT and HGT. Using the same pipelines demonstrated in Notebook 1: Node Classification Pipeline and Notebook 2: Link Prediction Pipeline, users can easily conduct node classification and link prediction tasks on the ACM dataset created by Notebook 0: Data Preparation.
Prerequisites
GraphStorm. Please find more details on installation of GraphStorm.
ACM data that has been created according to Notebook 0: Data Preparation, and is stored in the ./acm_gs_1p/ folder.
Installation of supporting libraries, e.g., matplotlib.
1. Revisit an RGCN model in the demo_models.py
Notebook 1 and Notebook 2 both use RGCN models that share the same GNN model architecture defined by GraphStorm. To modify a GraphStorm GNN model, let’s first revisit an RGCN model in the demo_models.py file. For simplicity, some docstrings are removed, and the code is restructured to fit in notebook cells.
[15]:
import graphstorm as gs
from graphstorm.model import GSgnnNodeModel, GSNodeEncoderInputLayer, RelationalGCNEncoder, EntityClassifier, ClassifyLossFunc

class RgcnNCModel(GSgnnNodeModel):
    """ A simple RGCN model for node classification using GraphStorm APIs
    """
    def __init__(self, g, num_hid_layers, node_feat_field, hid_size, num_classes, multilabel=False):
        super(RgcnNCModel, self).__init__(alpha_l2norm=0.)

        # extract node feature dimensions
        feat_size = gs.get_node_feat_size(g, node_feat_field)

        # set an input encoder
        encoder = GSNodeEncoderInputLayer(g=g, feat_size=feat_size, embed_size=hid_size)
        self.set_node_input_encoder(encoder)

        # set a GNN encoder
        gnn_encoder = RelationalGCNEncoder(g=g, h_dim=hid_size, out_dim=hid_size, num_hidden_layers=num_hid_layers-1)
        self.set_gnn_encoder(gnn_encoder)

        # set a decoder specific to node classification task
        decoder = EntityClassifier(in_dim=hid_size, num_classes=num_classes, multilabel=multilabel)
        self.set_decoder(decoder)

        # classification loss function
        self.set_loss_func(ClassifyLossFunc(multilabel=multilabel))

        # initialize model's optimizer
        self.init_optimizer(lr=0.001, sparse_optimizer_lr=0.01, weight_decay=0)
1.1 GraphStorm built-in model architecture
A GraphStorm built-in model normally contains four modules:
An input encoder that converts input node features to the embeddings with hidden dimensions.
A GNN encoder that takes the embeddings from the input layer and performs message passing computation.
A decoder that is task specific, e.g., the EntityClassifier for classification tasks.
A loss function that matches specific tasks, e.g., the ClassifyLossFunc.
Besides the four modules, a GraphStorm GNN model also needs to initialize its own optimizer object.
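The way these four modules chain together can be sketched with a toy class. This is only an illustration of the setter-based composition pattern seen in RgcnNCModel: ToyGnnModel and the four lambda modules are hypothetical stand-ins, not GraphStorm classes, and the "message passing" here is just a mean over all nodes.

```python
from typing import Callable, List, Optional

Vector = List[float]


class ToyGnnModel:
    """Hypothetical stand-in that mimics the set_* composition pattern."""

    def __init__(self) -> None:
        self.input_encoder: Optional[Callable[[Vector], Vector]] = None
        self.gnn_encoder: Optional[Callable[[Vector], Vector]] = None
        self.decoder: Optional[Callable[[Vector], float]] = None
        self.loss_func: Optional[Callable[[float, float], float]] = None

    def set_node_input_encoder(self, fn: Callable[[Vector], Vector]) -> None:
        self.input_encoder = fn

    def set_gnn_encoder(self, fn: Callable[[Vector], Vector]) -> None:
        self.gnn_encoder = fn

    def set_decoder(self, fn: Callable[[Vector], float]) -> None:
        self.decoder = fn

    def set_loss_func(self, fn: Callable[[float, float], float]) -> None:
        self.loss_func = fn

    def forward(self, feats: Vector, label: float) -> float:
        # input encoder -> GNN encoder -> decoder -> loss
        h = self.input_encoder(feats)
        h = self.gnn_encoder(h)
        pred = self.decoder(h)
        return self.loss_func(pred, label)


model = ToyGnnModel()
# toy "projection" of raw features
model.set_node_input_encoder(lambda x: [2.0 * v for v in x])
# toy "message passing": every node receives the mean of all nodes
model.set_gnn_encoder(lambda h: [sum(h) / len(h)] * len(h))
# toy decoder: reduce node embeddings to one prediction
model.set_decoder(lambda h: sum(h))
# toy loss: squared error
model.set_loss_func(lambda pred, label: (pred - label) ** 2)

loss = model.forward([1.0, 2.0, 3.0], label=10.0)
```

Swapping any one module, for example the GNN encoder, leaves the rest of the chain untouched, which is the property the following sections rely on.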
1.2 Model arguments
Each GNN model may have its own arguments. Some arguments, such as the input and output dimensions, are common across models, while others are model specific. For example, an RGCN model takes the number of bases to reduce the number of learnable parameters, and attention-based models may need the number of attention heads. GML tasks also need their own arguments. For example, classification tasks may have multiple labels.
GraphStorm APIs have given default values to many arguments. For better flexibility, we can add some arguments into model initialization, such as num_hid_layers and hid_size.
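To make the effect of the number of bases concrete, the sketch below counts the relation-specific weights of one RGCN layer with and without basis decomposition. It assumes the standard formulation in which each relation's weight matrix is a linear combination of shared basis matrices; the helper name rgcn_weight_params and the relation count of 6 are hypothetical, and biases and self-loop weights are ignored.

```python
from typing import Optional


def rgcn_weight_params(num_rels: int, in_dim: int, out_dim: int,
                       num_bases: Optional[int] = None) -> int:
    """Count relation-weight parameters in one RGCN layer.

    Without bases, every relation owns a full in_dim x out_dim matrix.
    With bases, the layer stores num_bases shared matrices plus one
    coefficient per (relation, basis) pair.
    """
    if num_bases is None:
        return num_rels * in_dim * out_dim
    return num_bases * in_dim * out_dim + num_rels * num_bases


# hypothetical graph with 6 edge types and a hidden size of 128
full = rgcn_weight_params(num_rels=6, in_dim=128, out_dim=128)
shared = rgcn_weight_params(num_rels=6, in_dim=128, out_dim=128, num_bases=2)
print(full, shared)
```

The saving grows with the number of edge types, because adding a relation then costs only num_bases extra coefficients instead of a full matrix.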
1.3 GML task modules
Besides model-related modules, a GNN model also contains task-specific modules, including task-specific decoders and loss functions. For example, to perform a node classification task, the above RgcnNCModel model chooses the EntityClassifier as its decoder and uses the ClassifyLossFunc as its loss function.
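The multilabel flag passed to both the decoder and the loss function changes how class scores are interpreted. The sketch below assumes the usual convention: single-label classification uses softmax cross-entropy, and multi-label classification uses an independent sigmoid loss per class. The internals of ClassifyLossFunc are not shown in this notebook, so the two functions here are illustrative, not GraphStorm code.

```python
import math
from typing import List


def cross_entropy(logits: List[float], target: int) -> float:
    """Single-label loss: softmax over classes, then negative log-likelihood."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def binary_cross_entropy(logits: List[float], targets: List[int]) -> float:
    """Multi-label loss: one sigmoid per class, averaged over classes."""
    total = 0.0
    for x, y in zip(logits, targets):
        # stable form of -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))]
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


# one node, three classes
logits = [2.0, 0.0, -1.0]
single = cross_entropy(logits, target=0)                # exactly one true class
multi = binary_cross_entropy(logits, targets=[1, 0, 1])  # several true classes
```

In the single-label case the target is one class index; in the multi-label case it is a 0/1 vector with one entry per class.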
2. Reproduce GraphStorm Built-in GNN Model Variants
Knowing the common architecture and arguments, it is easy to reproduce GraphStorm built-in GNN model variants.
2.1 Reproduce an RGAT Model for Node Classification
To turn the demo RgcnNCModel code into an RgatNCModel model, only two modifications are needed:
For the GNN encoder, replace the RelationalGCNEncoder with the RelationalGATEncoder.
Add some RelationalGATEncoder specific arguments in initialization.
Below is the simplified code of the RgatNCModel model. The complete code can be found in the demo_models.py file.
[2]:
from graphstorm.model import RelationalGATEncoder

class RgatNCModel(GSgnnNodeModel):
    """ A simple RGAT model for node classification using GraphStorm APIs
    """
    def __init__(self, g,
                 num_heads, # an argument specific to RelationalGATEncoder
                 num_hid_layers, node_feat_field, hid_size, num_classes, multilabel=False):
        super(RgatNCModel, self).__init__(alpha_l2norm=0.)

        # input encoder remains the same ......

        # set a GNN encoder
        gnn_encoder = RelationalGATEncoder(g=g, h_dim=hid_size, out_dim=hid_size,
                                           num_heads=num_heads, # pass the num_heads to the RelationalGATEncoder
                                           num_hidden_layers=num_hid_layers-1)
        self.set_gnn_encoder(gnn_encoder)

        # decoder, loss function, and optimizer initialization remain the same ......
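When choosing num_heads, it helps to check how it interacts with hid_size. The sketch below assumes the common multi-head convention in which the hidden dimension is divided evenly across heads, so hid_size must be a multiple of num_heads. Whether a given encoder follows exactly this convention should be confirmed in its documentation; split_heads is a hypothetical helper, not a GraphStorm function.

```python
from typing import List


def split_heads(vec: List[float], num_heads: int) -> List[List[float]]:
    """Split one hidden vector into num_heads equal-sized slices."""
    if len(vec) % num_heads != 0:
        raise ValueError(
            f"hidden size {len(vec)} is not divisible by num_heads={num_heads}")
    head_dim = len(vec) // num_heads
    return [vec[i * head_dim:(i + 1) * head_dim] for i in range(num_heads)]


# with hid_size=128 and num_heads=8, each head would see 16 dimensions
heads = split_heads([0.0] * 128, num_heads=8)
print(len(heads), len(heads[0]))
```

Under this assumption, a combination such as hid_size=128 with num_heads=3 would be rejected, so it is safer to pick the two values together.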
2.2 Reproduce an HGT Model with DistMult Decoder for Link Prediction
Similar to the RGAT variant, replacing the RelationalGCNEncoder with the HGTEncoder and setting the corresponding arguments reproduces an HGT model. In addition, this example also replaces the LinkPredictDotDecoder decoder with the LinkPredictDistMultDecoder, and sets its own arguments. Below is the simplified code of the HgtLPModel model. The complete code can be found in the demo_models.py file.
[3]:
from graphstorm.model import GSgnnLinkPredictionModel, HGTEncoder, LinkPredictDistMultDecoder

class HgtLPModel(GSgnnLinkPredictionModel):
    """ A simple HGT model for link prediction using GraphStorm APIs
    """
    def __init__(self, g,
                 num_heads, # an argument specific to HGTEncoder
                 num_hid_layers, node_feat_field, hid_size):
        super(HgtLPModel, self).__init__(alpha_l2norm=0.)

        # input encoder remains the same ......

        # set a GNN encoder
        gnn_encoder = HGTEncoder(g=g,
                                 num_heads=num_heads, # pass the num_heads to the HGTEncoder
                                 hid_dim=hid_size, out_dim=hid_size, num_hidden_layers=num_hid_layers-1)
        self.set_gnn_encoder(gnn_encoder)

        # set a decoder specific to link prediction task
        decoder = LinkPredictDistMultDecoder(etypes=g.canonical_etypes, # specifically added to the LinkPredictDistMultDecoder
                                             h_dim=hid_size)
        self.set_decoder(decoder)

        # loss function, and optimizer initialization remain the same ......
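The reason the DistMult decoder takes the edge types as an extra argument can be seen from the scoring functions themselves. A dot-product decoder scores an edge from the two node embeddings alone, while DistMult also multiplies in a learnable vector per relation. The sketch below uses plain Python lists; the embedding values and the rel_emb table are made up for illustration and do not come from a trained model.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]


def dot_score(src: List[float], dst: List[float]) -> float:
    """Dot-product score: depends only on the two node embeddings."""
    return sum(s * d for s, d in zip(src, dst))


def distmult_score(src: List[float], rel: List[float], dst: List[float]) -> float:
    """DistMult score: element-wise product with a per-relation vector."""
    return sum(s * r * d for s, r, d in zip(src, rel, dst))


# one relation vector per canonical edge type (hypothetical values)
rel_emb: Dict[EdgeType, List[float]] = {
    ("paper", "citing", "paper"): [1.0, 0.5, 2.0],
    ("paper", "cited", "paper"): [1.0, 1.0, 1.0],
}

src, dst = [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]
plain = dot_score(src, dst)
citing = distmult_score(src, rel_emb[("paper", "citing", "paper")], dst)
cited = distmult_score(src, rel_emb[("paper", "cited", "paper")], dst)
```

With an all-ones relation vector, DistMult reduces to the dot product, so the relation vectors are what let the same pair of nodes score differently under different edge types.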
3. Link Prediction Pipeline by Using the HGT Model
To use the above GNN model variant, the overall GML pipeline needs only a few modifications to accommodate the model-specific arguments. The example below reuses the link prediction pipeline of Notebook 2. For simplicity, this example combines multiple cells and comments.
3.1 Training pipeline
[16]:
import logging
import graphstorm as gs
logging.basicConfig(level=20)  # 20 == logging.INFO
gs.initialize()

nfeats_4_modeling = {'author':['feat'], 'paper':['feat'],'subject':['feat']}

acm_data = gs.dataloading.GSgnnData(part_config='./acm_gs_1p/acm.json', node_feat_field=nfeats_4_modeling)

train_dataloader = gs.dataloading.GSgnnLinkPredictionDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_edge_train_set(etypes=[('paper', 'citing', 'paper')]),
    fanout=[20, 20],
    num_negative_edges=10,
    node_feats=nfeats_4_modeling,
    batch_size=64,
    exclude_training_targets=True,
    reverse_edge_types_map={("paper", "citing", "paper"):("paper","cited","paper")},
    train_task=True)
val_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_edge_val_set(etypes=[('paper', 'citing', 'paper')]),
    fanout=[100, 100],
    num_negative_edges=100,
    node_feats=nfeats_4_modeling,
    batch_size=256)
test_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_edge_test_set(etypes=[('paper', 'citing', 'paper')]),
    fanout=[100, 100],
    num_negative_edges=100,
    node_feats=nfeats_4_modeling,
    batch_size=256)

from demo_models import HgtLPModel # Import the HGT model variant

model = HgtLPModel(g=acm_data.g,
                   num_heads=8,
                   num_hid_layers=2,
                   node_feat_field=nfeats_4_modeling,
                   hid_size=128)

evaluator = gs.eval.GSgnnMrrLPEvaluator(eval_frequency=1000)

trainer = gs.trainer.GSgnnLinkPredictionTrainer(model, topk_model_to_save=1)
trainer.setup_evaluator(evaluator)
trainer.setup_device(gs.utils.get_device())

trainer.fit(train_loader=train_dataloader,
            val_loader=val_dataloader,
            test_loader=test_dataloader,
            num_epochs=50,
            save_model_path='a_save_path/',
            save_model_frequency=1000,
            use_mini_batch_infer=True)
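The exclude_training_targets and reverse_edge_types_map arguments of the training dataloader work together. One reading of their purpose is that the edges being predicted in a batch should not be visible during message passing, and neither should their reverse copies, otherwise the model could read the answer from the graph. The sketch below shows that idea on plain edge lists. It is a toy illustration, not GraphStorm's implementation; exclude_targets and the small edge lists are hypothetical.

```python
from typing import Dict, List, Tuple

Edge = Tuple[int, int]
EdgeType = Tuple[str, str, str]


def exclude_targets(edges_by_type: Dict[EdgeType, List[Edge]],
                    targets: List[Edge],
                    etype: EdgeType,
                    reverse_map: Dict[EdgeType, EdgeType]) -> Dict[EdgeType, List[Edge]]:
    """Return a copy of the edge lists without the targets and their reverses."""
    target_set = set(targets)
    reversed_set = {(v, u) for u, v in targets}
    rev_etype = reverse_map.get(etype)
    pruned = {}
    for et, edges in edges_by_type.items():
        if et == etype:
            pruned[et] = [e for e in edges if e not in target_set]
        elif et == rev_etype:
            pruned[et] = [e for e in edges if e not in reversed_set]
        else:
            pruned[et] = list(edges)
    return pruned


citing = ("paper", "citing", "paper")
cited = ("paper", "cited", "paper")
writing = ("author", "writing", "paper")
edges = {
    citing: [(0, 1), (1, 2), (2, 3)],
    cited: [(1, 0), (2, 1), (3, 2)],
    writing: [(7, 0)],
}
# the batch asks the model to predict the edge paper 0 -> paper 1
pruned = exclude_targets(edges, targets=[(0, 1)], etype=citing,
                         reverse_map={citing: cited})
```

Without the reverse map, the (1, 0) edge of the cited type would remain in the graph and reveal the target edge.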
3.2 Visualize Model Performance History
[18]:
import matplotlib.pyplot as plt

val_metrics, test_metrics = [], []
for val_metric, test_metric in trainer.evaluator.history:
    val_metrics.append(val_metric['mrr'])
    test_metrics.append(test_metric['mrr'])

fig, ax = plt.subplots()
ax.plot(val_metrics, label='val')
ax.plot(test_metrics, label='test')
ax.set(xlabel='Epoch', ylabel='MRR')
ax.legend(loc='best')
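The mrr values plotted above are mean reciprocal ranks. For each positive edge, its score is ranked against the scores of its negative edges, and the reciprocal of that rank is averaged over all positive edges. The sketch below is a minimal pure-Python version of that computation. The scores are made up, and ties are resolved in favor of the positive edge, which is an assumption rather than a documented property of the GraphStorm evaluator.

```python
from typing import List, Tuple


def reciprocal_rank(pos_score: float, neg_scores: List[float]) -> float:
    """1 / rank of the positive edge among its negatives (rank 1 is best)."""
    # only strictly higher negatives push the positive edge down
    rank = 1 + sum(1 for s in neg_scores if s > pos_score)
    return 1.0 / rank


def mean_reciprocal_rank(samples: List[Tuple[float, List[float]]]) -> float:
    """Average the reciprocal ranks over all positive edges."""
    return sum(reciprocal_rank(p, n) for p, n in samples) / len(samples)


samples = [
    (0.9, [0.1, 0.95, 0.3]),  # one negative scores higher: rank 2
    (0.8, [0.1, 0.2]),        # no negative scores higher: rank 1
]
mrr = mean_reciprocal_rank(samples)
```

An MRR of 1.0 means every positive edge outranked all of its negatives, so higher curves in the plot indicate better link prediction.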
3.3 Inference pipeline
[19]:
best_model_path = trainer.get_best_model_path()
print('Best model path:', best_model_path)

model.restore_model(best_model_path)

infer_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_edge_infer_set(etypes=[('paper', 'citing', 'paper')]),
    fanout=[100, 100],
    num_negative_edges=100,
    node_feats=nfeats_4_modeling,
    batch_size=256)

infer = gs.inference.GSgnnLinkPredictionInferrer(model)

infer.infer(acm_data,
            infer_dataloader,
            save_embed_path='infer/embeddings',
            use_mini_batch_infer=True)