Notebook 2: Use GraphStorm APIs for Building a Link Prediction Pipeline
This notebook demonstrates how to use GraphStorm’s APIs to create a graph machine learning pipeline for a link prediction task.
In this notebook, we modify the RGCN model used in Notebook 1 to adapt it to link prediction tasks, and use it to run link prediction on the ACM dataset created by Notebook 0 (Notebook_0_Data_Prepare).
Prerequisites
- GraphStorm. See the GraphStorm installation documentation for details.
- ACM data that has been created according to Notebook 0: Data Preparation and is stored in the ./acm_gs_1p/ folder.
- Installation of supporting libraries, e.g., matplotlib.
[1]:
# Setup log level in Jupyter Notebook
import logging
logging.basicConfig(level=20)
The major steps of creating a link prediction pipeline are the same as those of the node classification pipeline in Notebook 1. In this notebook, we only highlight the components that differ.
0. Initialize the GraphStorm Standalone Environment
[1]:
import graphstorm as gs
gs.initialize()
1. Set up the GraphStorm Dataset and DataLoaders
[3]:
nfeats_4_modeling = {'author':['feat'], 'paper':['feat'],'subject':['feat']}
# create a GraphStorm Dataset for the ACM graph data generated in Notebook 0
acm_data = gs.dataloading.GSgnnData(part_config='./acm_gs_1p/acm.json', node_feat_field=nfeats_4_modeling)
Because link prediction needs both positive and negative edges for training, we use GraphStorm's GSgnnLinkPredictionDataLoader, which is dedicated to link prediction data loading. This class takes some of the same arguments as the node prediction dataloaders in Notebook 1, such as dataset, target_idx, node_feats, and batch_size. It also takes link prediction-specific arguments, e.g., num_negative_edges and exclude_training_targets.
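To make the idea of negative edges concrete, here is a minimal, framework-free sketch of uniform negative sampling: each positive (source, destination) pair is matched with num_negative_edges corrupted pairs that keep the source and draw a random destination. This is an illustration only, not GraphStorm's implementation; the function sample_negative_edges is hypothetical, and GraphStorm's own samplers may work differently.

```python
import random

def sample_negative_edges(pos_edges, num_nodes, num_negative_edges, seed=0):
    """Hypothetical helper: corrupt the destination of each positive edge.

    pos_edges: list of (src, dst) node-ID pairs (the positive edges).
    num_nodes: number of candidate destination nodes (IDs 0..num_nodes-1).
    Returns num_negative_edges (src, random_dst) pairs per positive edge.
    """
    rng = random.Random(seed)  # fixed seed for reproducibility
    neg_edges = []
    for src, _ in pos_edges:
        for _ in range(num_negative_edges):
            # Keep the source, pick a destination uniformly at random.
            neg_edges.append((src, rng.randrange(num_nodes)))
    return neg_edges

pos = [(0, 1), (2, 3)]
neg = sample_negative_edges(pos, num_nodes=5, num_negative_edges=10)
print(len(neg))  # 2 positive edges x 10 negatives each = 20
```

Uniform sampling is cheap, but it can occasionally draw a pair that is actually a true edge; with a large graph this is usually rare enough to tolerate.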
[2]:
# define dataloaders for training, validation, and testing
train_dataloader = gs.dataloading.GSgnnLinkPredictionDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_edge_train_set(etypes=[('paper', 'citing', 'paper')]),
    fanout=[20, 20],
    num_negative_edges=10,
    node_feats=nfeats_4_modeling,
    batch_size=64,
    exclude_training_targets=True,
    reverse_edge_types_map={("paper", "citing", "paper"): ("paper", "cited", "paper")},
    train_task=True)
val_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_edge_val_set(etypes=[('paper', 'citing', 'paper')]),
    fanout=[100, 100],
    num_negative_edges=100,
    node_feats=nfeats_4_modeling,
    batch_size=256)
test_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_edge_test_set(etypes=[('paper', 'citing', 'paper')]),
    fanout=[100, 100],
    num_negative_edges=100,
    node_feats=nfeats_4_modeling,
    batch_size=256)
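Setting exclude_training_targets=True together with reverse_edge_types_map is intended to stop the model from seeing the very edges it is asked to predict during message passing. As we understand it, for each mini-batch the target edges, plus their reverse edges given by the map, are removed from the sampled computation graph. The hypothetical helper exclude_targets below sketches that filtering on a plain Python edge list; it is not GraphStorm's code.

```python
def exclude_targets(edges, targets, reverse_etype_map):
    """Hypothetical helper: drop target edges and their reverse edges.

    edges: list of (etype, src, dst), etype being a (src_ntype, rel, dst_ntype) tuple.
    targets: list of (etype, src, dst) target edges in the current mini-batch.
    reverse_etype_map: maps an edge type to its reverse edge type.
    """
    banned = set()
    for etype, src, dst in targets:
        banned.add((etype, src, dst))
        rev = reverse_etype_map.get(etype)
        if rev is not None:
            # The reverse edge points the other way, so swap src and dst.
            banned.add((rev, dst, src))
    return [e for e in edges if e not in banned]

cite = ("paper", "citing", "paper")
cited = ("paper", "cited", "paper")
edges = [(cite, 0, 1), (cited, 1, 0), (cite, 2, 3), (cited, 3, 2)]
kept = exclude_targets(edges, targets=[(cite, 0, 1)],
                       reverse_etype_map={cite: cited})
print(kept)  # [(('paper', 'citing', 'paper'), 2, 3), (('paper', 'cited', 'paper'), 3, 2)]
```

Removing only the forward edge would not be enough here: the reverse "cited" edge would still carry the same information from node 1 back to node 0.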
2. Create a GraphStorm-compatible RGCN Model for Link Prediction
For the link prediction task, we modified the RGCN model used for node classification to adapt it to link prediction. Users can find the details in the demo_models.py file.
[5]:
# import a simplified RGCN model for link prediction
from demo_models import RgcnLPModel
model = RgcnLPModel(g=acm_data.g,
                    num_hid_layers=2,
                    node_feat_field=nfeats_4_modeling,
                    hid_size=128)
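The model code itself lives in demo_models.py and is not reproduced here. A common design for link prediction models, and the one assumed in this sketch, is to compute node embeddings with the GNN encoder, score each candidate edge with a decoder such as the dot product of its endpoint embeddings, and train with binary cross-entropy where positive edges are labeled 1 and sampled negative edges 0. Whether RgcnLPModel uses exactly this decoder and loss should be checked in demo_models.py; the functions below are hypothetical and framework-free.

```python
import math

def dot_score(h_src, h_dst):
    """Score an edge as the dot product of its endpoint embeddings."""
    return sum(a * b for a, b in zip(h_src, h_dst))

def bce_with_logits(score, label):
    """Numerically stable binary cross-entropy on a raw score (logit)."""
    # Equivalent to -[y*log(sigmoid(s)) + (1-y)*log(1-sigmoid(s))]
    return max(score, 0.0) - score * label + math.log1p(math.exp(-abs(score)))

def lp_loss(emb, pos_edges, neg_edges):
    """Average BCE loss: positive edges labeled 1, negative edges labeled 0."""
    losses = [bce_with_logits(dot_score(emb[u], emb[v]), 1.0) for u, v in pos_edges]
    losses += [bce_with_logits(dot_score(emb[u], emb[v]), 0.0) for u, v in neg_edges]
    return sum(losses) / len(losses)

# Toy embeddings (hypothetical): node 0 is aligned with node 1, orthogonal to node 2.
emb = {0: [1.0, 0.0], 1: [1.0, 0.0], 2: [0.0, 1.0]}
loss = lp_loss(emb, pos_edges=[(0, 1)], neg_edges=[(0, 2)])
print(loss)
```

Training pushes positive-edge scores up and negative-edge scores down, which is exactly what a ranking metric such as MRR later rewards.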
3. Set up a GraphStorm Evaluator
Here we switch the evaluator to GSgnnMrrLPEvaluator, which uses mean reciprocal rank ("mrr") as the metric for evaluating link prediction performance.
[6]:
# setup a link prediction evaluator for the trainer
evaluator = gs.eval.GSgnnMrrLPEvaluator(eval_frequency=1000)
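MRR averages, over all positive edges being evaluated, the reciprocal of the rank each positive edge receives when scored against its own negative edges (100 per positive edge in the validation and test dataloaders above). The sketch below computes it from plain Python lists; the function name is hypothetical, and the tie-breaking rule (ties count against the positive edge) is our assumption and may differ from GraphStorm's evaluator.

```python
def mean_reciprocal_rank(pos_scores, neg_scores):
    """Compute MRR for link prediction.

    pos_scores: list of scores, one per positive edge.
    neg_scores: list of lists; neg_scores[i] holds the scores of the
        negative edges sampled for positive edge i.
    The rank of a positive edge is 1 + the number of its negatives that
    score at least as high (pessimistic tie-breaking).
    """
    reciprocal_ranks = []
    for pos, negs in zip(pos_scores, neg_scores):
        rank = 1 + sum(1 for n in negs if n >= pos)
        reciprocal_ranks.append(1.0 / rank)
    return sum(reciprocal_ranks) / len(reciprocal_ranks)

# Edge 0 beats all its negatives (rank 1); edge 1 has two negatives above it (rank 3).
mrr = mean_reciprocal_rank([0.9, 0.4], [[0.1, 0.2, 0.3], [0.8, 0.5, 0.1]])
print(mrr)  # (1/1 + 1/3) / 2 = 2/3
```

MRR lies in (0, 1]; a value of 1 means every positive edge outranked all of its negatives.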
4. Set up a Trainer and Train the Model
GraphStorm provides GSgnnLinkPredictionTrainer for the link prediction training loop. It is constructed, and its fit() method is called, in the same way as the GSgnnNodePredictionTrainer used in Notebook 1.
[7]:
# create a GraphStorm link prediction task trainer for the RGCN model
trainer = gs.trainer.GSgnnLinkPredictionTrainer(model, topk_model_to_save=1)
trainer.setup_evaluator(evaluator)
trainer.setup_device(gs.utils.get_device())
[3]:
# Train the model with the trainer using fit() function
trainer.fit(train_loader=train_dataloader,
            val_loader=val_dataloader,
            test_loader=test_dataloader,
            num_epochs=50,
            save_model_path='a_save_path/',
            save_model_frequency=1000,
            use_mini_batch_infer=True)
(Optional) 5. Visualize Model Performance History
As in the node classification pipeline, we can plot the metric history stored in the trainer's evaluator.
[4]:
import matplotlib.pyplot as plt
# extract evaluation history of metrics from the trainer's evaluator:
val_metrics, test_metrics = [], []
for val_metric, test_metric in trainer.evaluator.history:
    val_metrics.append(val_metric['mrr'])
    test_metrics.append(test_metric['mrr'])
# plot the performance curves
fig, ax = plt.subplots()
ax.plot(val_metrics, label='val')
ax.plot(test_metrics, label='test')
ax.set(xlabel='Epoch', ylabel='MRR')
ax.legend(loc='best')
6. Inference with the Trained Model
Restoring the model works the same way as in Notebook 1: first find the path of the best model, then call the model's restore_model() method to load the trained model files.
[1]:
# after training, the best model is saved to disk:
best_model_path = trainer.get_best_model_path()
print('Best model path:', best_model_path)
[5]:
# we can restore the model from the saved path using the model's restore_model() function.
model.restore_model(best_model_path)
To do inference, users can either create a new dataloader as the following code does, or reuse one of the dataloaders defined in training.
[12]:
# Set up dataloader for inference
infer_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_edge_infer_set(etypes=[('paper', 'citing', 'paper')]),
    fanout=[100, 100],
    num_negative_edges=100,
    node_feats=nfeats_4_modeling,
    batch_size=256)
Now we can create a GSgnnLinkPredictionInferrer from the restored model and run inference by calling its infer() method.
[6]:
# Create an Inferrer object
infer = gs.inference.GSgnnLinkPredictionInferrer(model)
# Run inference on the inference dataset
infer.infer(acm_data,
            infer_dataloader,
            save_embed_path='infer/embeddings',
            use_mini_batch_infer=True)
For the link prediction task, the inference outputs are the embeddings of all nodes in the inference graph.
[7]:
# The GNN embeddings of all nodes in the inference graph are saved to the folder named after the target_ntype
!ls -lh infer/embeddings/paper
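One way to use these embeddings, assuming the trained model scores an edge by the dot product of its endpoint embeddings (check demo_models.py for the decoder RgcnLPModel actually uses), is to rank candidate papers that a given paper might cite. The sketch below works on an in-memory dict of embeddings; loading the files written under infer/embeddings into such a dict is left out because it depends on GraphStorm's output file format. The helper top_k_candidates and the toy values are hypothetical.

```python
import heapq

def top_k_candidates(embeddings, query, candidates, k, exclude=()):
    """Rank candidate nodes for a query node by dot-product score.

    embeddings: dict mapping node ID -> embedding (list of floats).
    Returns the k (node_id, score) pairs with the highest scores,
    skipping the query itself and any IDs in `exclude` (e.g. known links).
    """
    skip = set(exclude) | {query}
    q = embeddings[query]
    scored = ((c, sum(a * b for a, b in zip(q, embeddings[c])))
              for c in candidates if c not in skip)
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])

# Toy 2-D embeddings for four papers (hypothetical values).
emb = {0: [1.0, 0.0], 1: [0.9, 0.1], 2: [0.0, 1.0], 3: [0.5, 0.5]}
print(top_k_candidates(emb, query=0, candidates=emb.keys(), k=2))
# [(1, 0.9), (3, 0.5)]
```

heapq.nlargest keeps only k items in memory, which matters when scoring one paper against every other paper in a large graph.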