Notebook 4: Use GraphStorm APIs for Customizing Model Components
This notebook provides an example of how to customize components of GML models to fit specific requirements. Customized models should extend GraphStorm's higher-level APIs, which lets them implement their own functionality while still integrating easily into GraphStorm training and inference pipelines.
An Example of a Customized Model
A widely used GNN model is the RGAT model, proposed in Relational Graph Attention Networks. The original RGAT model weighs neighbors of a node by their importance and uses an attention mechanism to aggregate messages from neighbors within the same relation type. The aggregations from different relation types are then added together as the output representation of a node.
An alternative way to aggregate representations across different relation types is to use attention instead of summation. We can use an additional weight set to compute the attention coefficients for different relation types, and then compute the weighted sum of aggregations across relation types.
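In symbols, one way to write the two aggregation schemes described above is sketched here. The notation is an assumption chosen to match the Ara_GatLayer code in Section 2.1 and may differ from the symbols used in the RGAT paper: z_i^{(r)} is the within-relation GAT aggregation for node i under relation type r, \mathcal{R} is the set of relation types that reach node i, and \mathbf{w} is the additional across-relation weight vector.

```latex
% summation across relation types (original RGAT output)
h_i' = \sum_{r \in \mathcal{R}} z_i^{(r)}

% across-relation attention (ARA_GAT)
e_i^{(r)} = \mathbf{w}^{\top} z_i^{(r)}, \qquad
\beta_i^{(r)} = \frac{\exp\big(e_i^{(r)}\big)}{\sum_{r' \in \mathcal{R}} \exp\big(e_i^{(r')}\big)}, \qquad
h_i' = \sum_{r \in \mathcal{R}} \beta_i^{(r)} z_i^{(r)}
```

The coefficients \beta_i^{(r)} sum to one over the relation types, so the output is a convex combination of the per-relation aggregations rather than their plain sum.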
In this notebook, we will implement this Across Relation Attention GAT (ARA_GAT), fit it into the GraphStorm model architecture, and run training and inference using the existing pipelines.
Prerequisites
GraphStorm. Please find more details on installation of GraphStorm.
ACM data that has been created according to Notebook 0: Data Preparation, and is stored in the ./acm_gs_1p/ folder.
Installation of supporting libraries, e.g., matplotlib.
1. Recap GraphStorm Model Architecture
As explained in Notebook 3: Use GraphStorm APIs for Implementing Built-in GNN Models, a GraphStorm model normally contains four modules:
An input encoder that converts input node features to the embeddings with hidden dimensions.
A GNN encoder that takes the embeddings from the input encoder and performs message passing computation.
A decoder that is task specific, e.g., the EntityClassifier for classification tasks.
A loss function that matches specific tasks, e.g., the ClassifyLossFunc.
Given this architecture, it is clear that we only need to build a customized GNN encoder that implements the ARA_GAT variant and leave the other modules untouched.
2. ARA_GAT Variant Encoder Implementation
To build a customized GNN encoder, we can refer to the implementation of GraphStorm's GNN encoders, such as graphstorm.model.RelationalGATEncoder, which extends graphstorm.model.GraphConvEncoder and implements the required forward() method.
The code in the cells below includes a layer module named Ara_GatLayer, which performs the ARA_GAT computation for one GNN layer, and an encoder module named Ara_GatEncoder, which extends graphstorm.model.gnn_encoder_base.GraphConvEncoder.
2.1 Ara_GatLayer Implementation
[5]:
import dgl
import torch as th
import torch.nn as nn

class Ara_GatLayer(nn.Module):
    """ One layer of ARA_GAT
    """
    def __init__(self, in_dim, out_dim, num_heads, rel_names, bias=True,
                 activation=None, self_loop=False, dropout=0.0, norm=None):
        super(Ara_GatLayer, self).__init__()
        self.bias = bias
        self.activation = activation
        self.self_loop = self_loop
        self.leaky_relu = nn.LeakyReLU(0.2)

        # GAT module for each relation type
        self.rel_gats = nn.ModuleDict()
        for rel in rel_names:
            # out_dim should be divisible by num_heads
            self.rel_gats[str(rel)] = dgl.nn.GATConv(in_dim, out_dim // num_heads,
                                                     num_heads, allow_zero_in_degree=True)

        # across-relation attention weight set
        self.acr_attn_weights = nn.Parameter(th.Tensor(out_dim, 1))
        nn.init.normal_(self.acr_attn_weights)

        # bias
        if bias:
            self.h_bias = nn.Parameter(th.Tensor(out_dim))
            nn.init.zeros_(self.h_bias)

        # weight for self loop
        if self.self_loop:
            self.loop_weight = nn.Parameter(th.Tensor(in_dim, out_dim))
            nn.init.xavier_uniform_(self.loop_weight, gain=nn.init.calculate_gain('relu'))

        # dropout
        self.dropout = nn.Dropout(dropout)

        # normalization for each node type
        ntypes = set()
        for rel in rel_names:
            ntypes.add(rel[0])
            ntypes.add(rel[2])

        # normalization layers are modules, so they are held in a ModuleDict
        if norm == "batch":
            self.norm = nn.ModuleDict({ntype: nn.BatchNorm1d(out_dim) for ntype in ntypes})
        elif norm == "layer":
            self.norm = nn.ModuleDict({ntype: nn.LayerNorm(out_dim) for ntype in ntypes})
        else:
            self.norm = None

    def forward(self, g, inputs):
        """
        g: DGL.block
            A DGL block
        inputs : dict[str, torch.Tensor]
            Node feature for each node type.

        Returns
        -------
        dict[str, torch.Tensor]
            New node features for each node type.
        """
        g = g.local_var()

        # loop over each edge type to perform the GAT computation within that edge type
        for src_type, e_type, dst_type in g.canonical_etypes:
            # extract subgraph of each edge type
            sub_graph = g[src_type, e_type, dst_type]
            # skip this edge type if it has no edges
            if sub_graph.num_edges() == 0:
                continue
            # extract source and destination node features
            src_feat = inputs[src_type]
            dst_feat = inputs[dst_type][:sub_graph.num_dst_nodes()]
            # GAT in one relation type
            agg = self.rel_gats[str((src_type, e_type, dst_type))](sub_graph, (src_feat, dst_feat))
            # flatten the per-head outputs into one vector per destination node
            agg = agg.view(agg.shape[0], -1)
            # store aggregations in destination nodes
            sub_graph.dstdata['agg_' + str((src_type, e_type, dst_type))] = self.leaky_relu(agg)

        h = {}
        for n_type in g.dsttypes:
            if g.num_dst_nodes(n_type) == 0:
                continue
            # collect the per-relation aggregations stored on this node type
            agg_list = []
            for k, v in g.dstnodes[n_type].data.items():
                if k.startswith('agg_'):
                    agg_list.append(v)
            # cross-relation attention
            if agg_list:
                acr_agg = th.stack(agg_list, dim=1)
                acr_att = th.matmul(acr_agg, self.acr_attn_weights)
                acr_sfm = th.softmax(acr_att, dim=1)
                # cross-relation weighted aggregation
                acr_sum = (acr_agg * acr_sfm).sum(dim=1)
            elif not self.self_loop:
                raise ValueError(f'Some nodes in the {n_type} type have no in-degree. ' +
                                 'Please check the data or set "self_loop=True"')
            # process new features
            if self.self_loop:
                if agg_list:
                    h_n = acr_sum + th.matmul(inputs[n_type][:g.num_dst_nodes(n_type)], self.loop_weight)
                else:
                    h_n = th.matmul(inputs[n_type][:g.num_dst_nodes(n_type)], self.loop_weight)
            else:
                # without a self loop, the output is the across-relation aggregation itself
                h_n = acr_sum
            if self.bias:
                h_n = h_n + self.h_bias
            if self.activation:
                h_n = self.activation(h_n)
            if self.norm:
                h_n = self.norm[n_type](h_n)
            h_n = self.dropout(h_n)
            h[n_type] = h_n
        return h
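To make the across-relation attention step easier to follow, here is a minimal sketch of the same arithmetic for a single node, written in plain Python so that each step is visible. The function name across_relation_attention and the toy numbers are assumptions for illustration only; the layer above does the same thing on batched tensors with th.stack, th.matmul and th.softmax.

```python
import math

def across_relation_attention(rel_aggs, attn_weights):
    """Combine per-relation aggregations for ONE node with softmax attention.

    rel_aggs: list of R vectors (lists of d floats), one per relation type.
    attn_weights: list of d floats, the shared across-relation weight vector.
    Returns (coefficients, combined_vector).
    """
    # one scalar score per relation: dot(aggregation, weights)
    scores = [sum(a * w for a, w in zip(agg, attn_weights)) for agg in rel_aggs]
    # numerically stable softmax over the relation axis
    max_score = max(scores)
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    coeffs = [e / total for e in exps]
    # weighted sum of the relation vectors, dimension by dimension
    dim = len(rel_aggs[0])
    combined = [sum(c * agg[k] for c, agg in zip(coeffs, rel_aggs)) for k in range(dim)]
    return coeffs, combined

# two relation types with 3-dimensional aggregations (made-up numbers)
rel_aggs = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0]]
attn_weights = [0.5, -0.5, 0.25]
coeffs, combined = across_relation_attention(rel_aggs, attn_weights)

# with all-zero weights every relation gets the same coefficient (a plain average)
uniform_coeffs, averaged = across_relation_attention(rel_aggs, [0.0, 0.0, 0.0])
```

With all-zero attention weights the result is the average of the per-relation aggregations, i.e. the summation used by the original RGAT scaled by 1/R; learned weights let the model move away from that uniform mix.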
2.2 Ara_GatEncoder Implementation
Here, we implement the Ara_GatEncoder by extending the base GraphStorm encoder, graphstorm.model.gnn_encoder_base.GraphConvEncoder, and implementing the forward(self, blocks, h) function to make this class compatible with the GraphStorm model architecture. The forward() function takes a list of DGL blocks and a dictionary of node representations as input arguments, and returns a dictionary that contains the new node representations. This forward() function will be called by GraphStorm model classes within their own forward() functions.
[6]:
from graphstorm.model.gnn_encoder_base import GraphConvEncoder
import torch.nn.functional as F

class Ara_GatEncoder(GraphConvEncoder):
    """ Across Relation Attention GAT Encoder by extending Graphstorm APIs
    """
    def __init__(self, g, h_dim, out_dim, num_heads, num_hidden_layers=1,
                 dropout=0, use_self_loop=True, norm='batch'):
        super(Ara_GatEncoder, self).__init__(h_dim, out_dim, num_hidden_layers)

        # h2h
        for _ in range(num_hidden_layers):
            self.layers.append(Ara_GatLayer(h_dim, h_dim, num_heads, g.canonical_etypes,
                                            activation=F.relu, self_loop=use_self_loop,
                                            dropout=dropout, norm=norm))
        # h2o
        self.layers.append(Ara_GatLayer(h_dim, out_dim, num_heads, g.canonical_etypes,
                                        activation=F.relu, self_loop=use_self_loop, norm=norm))

    def forward(self, blocks, h):
        """ accept block list and feature dictionary as inputs
        """
        for layer, block in zip(self.layers, blocks):
            h = layer(block, h)
        return h
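The forward() contract of the encoder can be illustrated without any graph library. The sketch below is a toy stand-in, not GraphStorm code: run_encoder and make_toy_layer are hypothetical names, the "blocks" are placeholder strings, and each toy layer only rescales the features. It mirrors the loop in Ara_GatEncoder.forward(), where one layer consumes one block, and the constructor logic, which builds num_hidden_layers hidden layers plus one output layer.

```python
def run_encoder(layers, blocks, h):
    """Mimic the loop in Ara_GatEncoder.forward(): one layer per block."""
    for layer, block in zip(layers, blocks):
        h = layer(block, h)
    return h

def make_toy_layer(scale):
    # stand-in for Ara_GatLayer: takes (block, feature dict) and returns a new dict
    def layer(block, h):
        return {ntype: [scale * x for x in feats] for ntype, feats in h.items()}
    return layer

num_hidden_layers = 1
# like the encoder above: num_hidden_layers h2h layers plus one h2o layer
layers = [make_toy_layer(2.0) for _ in range(num_hidden_layers)] + [make_toy_layer(10.0)]
blocks = ["block_0", "block_1"]  # placeholders; real blocks come from the dataloader
h_in = {"paper": [1.0, 2.0], "author": [3.0]}
h_out = run_encoder(layers, blocks, h_in)

# with fewer blocks than layers, zip() stops early and the last layer is skipped
h_short = run_encoder(layers, blocks[:1], h_in)
```

Because zip() truncates silently, a mismatch between the number of layers and the number of sampled blocks does not raise an error in this loop, so it is worth keeping the encoder depth and the dataloader fanout list the same length.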
3. Build a Node Classification Model based on the Ara_GatEncoder
The RgatNCModel below follows the same node classification model architecture used in Notebook 1: Use GraphStorm APIs for Building a Node Classification Pipeline. For the GNN encoder components, this model provides the option to use either the Ara_GatEncoder or the built-in RelationalGATEncoder from GraphStorm.
[20]:
import graphstorm as gs
from graphstorm.model import (GSgnnNodeModel, GSNodeEncoderInputLayer, RelationalGATEncoder,
                              EntityClassifier, ClassifyLossFunc)

class RgatNCModel(GSgnnNodeModel):
    """ A customized RGAT model for node classification using Graphstorm APIs
    """
    def __init__(self, g, num_heads, num_hid_layers, node_feat_field, hid_size, num_classes,
                 multilabel=False,
                 encoder_type='ara'     # option for different rgat encoders
                 ):
        super(RgatNCModel, self).__init__(alpha_l2norm=0.)

        # extract feature size
        feat_size = gs.get_node_feat_size(g, node_feat_field)

        # set an input layer encoder
        encoder = GSNodeEncoderInputLayer(g=g, feat_size=feat_size, embed_size=hid_size)
        self.set_node_input_encoder(encoder)

        # set the option of using either customized RGAT or built-in RGAT encoder
        if encoder_type == 'ara':
            gnn_encoder = Ara_GatEncoder(g, hid_size, hid_size, num_heads,
                                         num_hidden_layers=num_hid_layers-1)
        elif encoder_type == 'rgat':
            gnn_encoder = RelationalGATEncoder(g, hid_size, hid_size, num_heads,
                                               num_hidden_layers=num_hid_layers-1)
        else:
            raise Exception(f'Not supported encoders "{encoder_type}".')
        self.set_gnn_encoder(gnn_encoder)

        # set a decoder specific to node classification task
        decoder = EntityClassifier(in_dim=hid_size, num_classes=num_classes, multilabel=multilabel)
        self.set_decoder(decoder)

        # classification loss function
        self.set_loss_func(ClassifyLossFunc(multilabel=multilabel))

        # initialize model's optimizer
        self.init_optimizer(lr=0.001, sparse_optimizer_lr=0.01, weight_decay=0)
4. Node Classification Pipeline Using the RgatNCModel Model
The overall pipeline for using the customized model for node classification tasks is identical to the one in Notebook 1: Use GraphStorm APIs for Building a Node Classification Pipeline.
4.1 Training pipeline
[1]:
import logging
logging.basicConfig(level=20)

import graphstorm as gs
gs.initialize()

acm_data = gs.dataloading.GSgnnData(part_config='./acm_gs_1p/acm.json')

nfeats_4_modeling = {'author':['feat'], 'paper':['feat'],'subject':['feat']}

train_dataloader = gs.dataloading.GSgnnNodeDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_node_train_set(ntypes=['paper']),
    node_feats=nfeats_4_modeling,
    label_field='label',
    fanout=[20, 20],
    batch_size=64,
    train_task=True)
val_dataloader = gs.dataloading.GSgnnNodeDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_node_val_set(ntypes=['paper']),
    node_feats=nfeats_4_modeling,
    label_field='label',
    fanout=[100, 100],
    batch_size=256,
    train_task=False)
test_dataloader = gs.dataloading.GSgnnNodeDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_node_test_set(ntypes=['paper']),
    node_feats=nfeats_4_modeling,
    label_field='label',
    fanout=[100, 100],
    batch_size=256,
    train_task=False)

model = RgatNCModel(g=acm_data.g, num_heads=8, num_hid_layers=2, node_feat_field=nfeats_4_modeling,
                    hid_size=128, num_classes=14, encoder_type='ara')

evaluator = gs.eval.GSgnnClassificationEvaluator(eval_frequency=100)

trainer = gs.trainer.GSgnnNodePredictionTrainer(model)
trainer.setup_evaluator(evaluator)
trainer.setup_device(gs.utils.get_device())

trainer.fit(train_loader=train_dataloader,
            val_loader=val_dataloader,
            test_loader=test_dataloader,
            num_epochs=50,
            save_model_path='a_save_path/')
4.2 Visualize Model Performance History
[11]:
import matplotlib.pyplot as plt

val_metrics, test_metrics = [], []
for val_metric, test_metric in trainer.evaluator.history:
    val_metrics.append(val_metric['accuracy'])
    test_metrics.append(test_metric['accuracy'])

fig, ax = plt.subplots()
ax.plot(val_metrics, label='val')
ax.plot(test_metrics, label='test')
ax.set(xlabel='Epoch', ylabel='Accuracy')
ax.legend(loc='best')
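Besides plotting, the same history can be scanned for the evaluation step with the highest validation accuracy. The sketch below assumes only what the plotting loop above assumes, namely that the history is a sequence of (val_metric, test_metric) pairs of dictionaries keyed by metric name. The helper best_validation_step and the toy numbers are hypothetical and are not part of the GraphStorm API.

```python
def best_validation_step(history, metric="accuracy"):
    """Return (index, val_score, test_score) for the entry with the highest validation score."""
    best_idx = max(range(len(history)), key=lambda i: history[i][0][metric])
    val_metric, test_metric = history[best_idx]
    return best_idx, val_metric[metric], test_metric[metric]

# made-up history with the same shape the plotting loop iterates over
toy_history = [
    ({"accuracy": 0.41}, {"accuracy": 0.40}),
    ({"accuracy": 0.63}, {"accuracy": 0.61}),
    ({"accuracy": 0.59}, {"accuracy": 0.62}),
]
idx, val_acc, test_acc = best_validation_step(toy_history)
```

Selecting by validation score and then reporting the test score at that same step avoids picking a model based on the test set.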
4.3 Inference pipeline
[12]:
best_model_path = trainer.get_best_model_path()
print('Best model path:', best_model_path)

model.restore_model(best_model_path)

infer_dataloader = gs.dataloading.GSgnnNodeDataLoader(
    dataset=acm_data,
    target_idx=acm_data.get_node_test_set(ntypes=['paper']),
    node_feats=nfeats_4_modeling,
    label_field='label',
    fanout=[100, 100],
    batch_size=256,
    train_task=False)

infer = gs.inference.GSgnnNodePredictionInferrer(model)

infer.infer(infer_dataloader,
            save_embed_path='infer/embeddings',
            save_prediction_path='infer/predictions',
            use_mini_batch_infer=True)