graphstorm.model

GraphStorm provides a set of Graph Neural Network (GNN) modules. By combining them in proper ways, users can build various GNN models for different tasks.

A GNN model in GraphStorm normally contains four components:

  • Input layer: an input encoder that converts input node/edge features into embeddings of the given hidden dimension. The output of the input layer becomes the input of the GNN layer, or of the decoder layer if no GNN layer is defined.

  • GNN layer (Optional): a GNN encoder that performs the message passing computation. The outputs of a GNN layer are node embeddings that will be used in the decoder layer.

  • Decoder layer: a task-specific decoder that converts results from either a GNN layer or an input layer into loss values for different graph machine learning (GML) tasks, e.g., classification, regression, or link prediction.

  • Model optimizer: GraphStorm model classes have a built-in model optimizer, which should be initialized during GraphStorm GNN model object construction.
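The data flow between these components can be sketched without any framework. The class `ToyGNNModel` below is hypothetical and is not a GraphStorm API: it uses plain Python callables in place of the real input encoder, GNN encoder, decoder, and loss function, and it omits the model optimizer entirely. In this sketch the decoder produces a prediction and a separate loss callable turns it into a loss value, mirroring the Loss Function classes listed further down. It only illustrates the fallback where the decoder consumes input-layer embeddings when no GNN layer is defined.

```python
from typing import Callable, List, Optional

Vector = List[float]


class ToyGNNModel:
    """Hypothetical, framework-free illustration of the component layout.

    Not a GraphStorm class: every component is a plain Python callable and
    the model optimizer is omitted.
    """

    def __init__(
        self,
        input_layer: Callable[[Vector], Vector],
        decoder: Callable[[Vector], float],
        loss_func: Callable[[float, float], float],
        gnn_layer: Optional[Callable[[Vector], Vector]] = None,
    ):
        self.input_layer = input_layer  # features -> hidden embeddings
        self.gnn_layer = gnn_layer      # optional message passing step
        self.decoder = decoder          # embeddings -> task-specific prediction
        self.loss_func = loss_func      # (prediction, label) -> loss value

    def forward(self, feats: Vector, label: float) -> float:
        h = self.input_layer(feats)
        if self.gnn_layer is not None:
            # With a GNN layer, the decoder sees the message passing outputs...
            h = self.gnn_layer(h)
        # ...otherwise it consumes the input-layer embeddings directly.
        pred = self.decoder(h)
        return self.loss_func(pred, label)


# Usage: an input layer that doubles features, a sum "decoder", squared error.
model = ToyGNNModel(
    input_layer=lambda x: [2.0 * v for v in x],
    decoder=sum,
    loss_func=lambda p, y: (p - y) ** 2,
)
print(model.forward([1.0, 2.0], 5.0))  # 1.0
```

Passing a `gnn_layer` callable inserts the message passing step between the two; leaving it out reproduces the "decoder layer if no GNN layer is defined" path.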

If users would like to implement their own GNN models, a suggested practice is to extend a base GNN model class and its corresponding interface, e.g., GSgnnNodeModelBase and GSgnnNodeModelInterface, and implement the required abstract methods.

If users just want to build their own message passing methods, a suggested practice is to create their own GNN encoders by extending the GraphConvEncoder base class and implementing the forward(self, blocks, h) method, which GraphStorm GNN model classes call within their own forward() method.
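The `forward(self, blocks, h)` contract can be illustrated with a toy, framework-free sketch. Everything here is an assumption for illustration: `ToyGraphConvEncoder` is a hypothetical stand-in for GraphConvEncoder (the real class is a PyTorch module that receives DGL blocks and per-node-type tensors), a "block" is modeled as a plain dict from each destination node to its source neighbors, and `MeanAggEncoder` is a hypothetical encoder that averages each node with its in-neighbors once per block.

```python
from typing import Dict, List

Embeddings = Dict[int, List[float]]
Block = Dict[int, List[int]]  # hypothetical: destination node -> source (in-)neighbors


class ToyGraphConvEncoder:
    """Hypothetical stand-in for GraphConvEncoder; not the real GraphStorm class."""

    def forward(self, blocks: List[Block], h: Embeddings) -> Embeddings:
        raise NotImplementedError


class MeanAggEncoder(ToyGraphConvEncoder):
    """Toy message passing: average each node with its in-neighbors, once per block."""

    def forward(self, blocks: List[Block], h: Embeddings) -> Embeddings:
        for block in blocks:  # one block per GNN layer, applied in order
            new_h = {}
            for dst, srcs in block.items():
                # Include the destination node itself, like a self-loop.
                msgs = [h[dst]] + [h[s] for s in srcs]
                new_h[dst] = [sum(col) / len(msgs) for col in zip(*msgs)]
            h = new_h  # only destination nodes are carried to the next layer
        return h


h0 = {0: [0.0], 1: [2.0], 2: [4.0]}
blocks = [{0: [1, 2], 1: [2]}, {0: [1]}]
print(MeanAggEncoder().forward(blocks, h0))  # {0: [2.5]}
```

As with sampled mini-batches, the set of nodes shrinks layer by layer: the first block produces embeddings for nodes 0 and 1, and the second block uses those to produce the final embedding of node 0.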

For examples of how to use these GraphStorm APIs to form training/inference pipelines, to switch different GNN encoders to implement various GNN models, and to build a customized GNN encoder, please refer to GraphStorm API Programming Examples.

Base GNN models

GSgnnModelBase

GraphStorm GNN model base class.

GSgnnNodeModelBase

GraphStorm base model class for node prediction tasks.

GSgnnNodeModelInterface

The interface for GraphStorm node prediction model.

GSgnnEdgeModelBase

GraphStorm GNN model base class for edge prediction tasks.

GSgnnEdgeModelInterface

The interface for GraphStorm edge prediction model.

GSgnnLinkPredictionModelBase

GraphStorm GNN model base class for link-prediction tasks.

GSgnnLinkPredictionModelInterface

The interface for GraphStorm link prediction model.

Input Layer

GSPureLearnableInputLayer

The node encoder input layer for heterogeneous graphs.

GSNodeEncoderInputLayer

The node encoder input layer for all nodes in a heterogeneous graph.

GSLMNodeEncoderInputLayer

The node encoder input layer with language model (LM) support for all nodes in a heterogeneous graph.

GSPureLMNodeInputLayer

The node encoder input embedding layer with language model (LM) support only.

GNN Layer

RelGraphConvLayer

Relational graph convolution layer from Modeling Relational Data with Graph Convolutional Networks.

RelationalGCNEncoder

Relational graph conv encoder.

RelationalAttLayer

Relational graph attention layer from Relational Graph Attention Networks.

RelationalGATEncoder

Relational graph attention encoder.

HGTLayer

Heterogeneous graph transformer (HGT) layer from Heterogeneous Graph Transformer.

HGTLayerwithEdgeFeat

Heterogeneous graph transformer (HGT) layer with edge feature support.

HGTEncoder

Heterogeneous Graph Transformer (HGT) encoder.

SAGEConv

GraphSAGE convolutional layer from Inductive Representation Learning on Large Graphs.

SAGEEncoder

GraphSAGE Conv Encoder.

GATConv

Graph attention layer from Graph Attention Networks.

GATEncoder

GAT Conv Encoder.

GATv2Conv

GATv2 convolutional layer from How Attentive are Graph Attention Networks?

GATv2Encoder

GATv2 Conv Encoder.

GATConvwithEdgeFeat

Graph attention layer with edge feature support in message passing computation.

Decoder Layer

EntityClassifier

Decoder for node classification tasks.

EntityRegression

Decoder for node regression tasks.

DenseBiDecoder

Dense bilinear decoder for edge prediction tasks.

MLPEdgeDecoder

MLP-based decoder for edge prediction tasks.

MLPEFeatEdgeDecoder

MLP-based decoder for edge prediction tasks with edge feature support.

EdgeRegression

Decoder for edge regression tasks.

LinkPredictDotDecoder

Decoder for link prediction using the dot product as the score function.
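The dot-product score function can be sketched in a few lines. This is a minimal illustration using plain Python lists in place of batched tensors; `dot_score` and `score_edges` are hypothetical helper names, not GraphStorm APIs.

```python
from typing import Dict, List, Tuple


def dot_score(src_emb: List[float], dst_emb: List[float]) -> float:
    """Edge score = <h_src, h_dst>; a higher score means a more likely edge."""
    return sum(a * b for a, b in zip(src_emb, dst_emb))


def score_edges(emb: Dict[int, List[float]], edges: List[Tuple[int, int]]) -> List[float]:
    """Score a batch of candidate (src, dst) edges from node embeddings."""
    return [dot_score(emb[u], emb[v]) for u, v in edges]


emb = {0: [1.0, 2.0, 3.0], 1: [4.0, 5.0, 6.0], 2: [-1.0, 0.0, 0.0]}
print(score_edges(emb, [(0, 1), (0, 2)]))  # [32.0, -1.0]
```

The score has no relation-specific parameters, so every edge type is scored the same way from the two endpoint embeddings.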

LinkPredictContrastiveDotDecoder

Decoder for link prediction designed for contrastive loss using the dot product as the score function.

LinkPredictDistMultDecoder

Decoder for link prediction using the DistMult as the score function.
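DistMult scores a (head, relation, tail) triple with a trilinear product, sum_i h_i * r_i * t_i. The sketch below assumes a relation vector `r` for the edge type and uses plain lists; `distmult_score` is a hypothetical helper, not a GraphStorm API.

```python
from typing import List


def distmult_score(h: List[float], r: List[float], t: List[float]) -> float:
    """DistMult: sum_i h_i * r_i * t_i, with r a relation (edge type) vector."""
    return sum(hi * ri * ti for hi, ri, ti in zip(h, r, t))


print(distmult_score([1.0, 2.0], [3.0, 4.0], [5.0, 6.0]))  # 63.0
```

Because the product is elementwise, swapping head and tail gives the same score, so DistMult cannot distinguish the direction of an edge.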

LinkPredictContrastiveDistMultDecoder

Decoder for link prediction designed for contrastive loss using the DistMult as the score function.

LinkPredictRotatEDecoder

Decoder for link prediction using the RotatE as the score function.
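RotatE treats each relation as an elementwise rotation in complex space: r_i = exp(j * theta_i) has unit modulus, and a triple scores well when h rotated by r lands near t. The sketch below follows the common formulation gamma - sum_i |h_i * r_i - t_i|; the margin `gamma = 12.0` and the helper name `rotate_score` are illustrative assumptions, and GraphStorm's exact margin and norm choices are not asserted here.

```python
import cmath
import math
from typing import List


def rotate_score(h: List[complex], phases: List[float], t: List[complex], gamma: float = 12.0) -> float:
    """RotatE-style score: gamma - sum_i |h_i * r_i - t_i|, with r_i = exp(j * phase_i)."""
    dist = 0.0
    for hi, theta, ti in zip(h, phases, t):
        ri = cmath.exp(1j * theta)  # unit-modulus rotation in the complex plane
        dist += abs(hi * ri - ti)
    return gamma - dist


# Rotating 1 by 90 degrees lands on 1j, so the distance is ~0 and the score is ~gamma.
print(rotate_score([1 + 0j], [math.pi / 2], [1j]))
```

Rotations compose and invert naturally, which is why RotatE can model patterns such as inverse relations that a symmetric score like DistMult cannot.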

LinkPredictContrastiveRotatEDecoder

Decoder for link prediction designed for contrastive loss using the RotatE as the score function.

LinkPredictWeightedRotatEDecoder

Link prediction decoder with the score function of RotatE with edge weights.

LinkPredictTransEDecoder

Decoder for link prediction using the TransE as the score function.
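TransE models a relation as a translation: a triple scores well when t is close to h + r. A minimal sketch of the usual form gamma - ||h + r - t||_p follows; the margin `gamma = 12.0`, the default `p = 2`, and the helper name `transe_score` are illustrative assumptions rather than GraphStorm's settings.

```python
from typing import List


def transe_score(h: List[float], r: List[float], t: List[float], gamma: float = 12.0, p: int = 2) -> float:
    """TransE-style score: gamma - ||h + r - t||_p."""
    dist = sum(abs(hi + ri - ti) ** p for hi, ri, ti in zip(h, r, t)) ** (1.0 / p)
    return gamma - dist


print(transe_score([1.0, 2.0], [1.0, 1.0], [2.0, 3.0]))  # 12.0 (t == h + r)
print(transe_score([0.0, 0.0], [0.0, 0.0], [3.0, 4.0]))  # 7.0 (L2 distance is 5)
```

Setting `p=1` switches to the L1 distance, the other norm commonly used with TransE.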

LinkPredictContrastiveTransEDecoder

Decoder for link prediction designed for contrastive loss using the TransE as the score function.

LinkPredictWeightedTransEDecoder

Link prediction decoder with the score function of TransE with edge weights.

Loss Function

ClassifyLossFunc

Loss function for classification tasks.

FocalLossFunc

Focal loss function for classification tasks.
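Focal loss rescales cross entropy by (1 - p_t)^gamma so that confidently classified examples contribute less. The sketch below is the standard binary formulation, -alpha_t * (1 - p_t)^gamma * log(p_t), on a single predicted probability; the defaults `alpha = 0.25` and `gamma = 2.0` are illustrative, and whether GraphStorm's FocalLossFunc uses these values or also covers multi-class cases is not asserted. `binary_focal_loss` is a hypothetical helper.

```python
import math


def binary_focal_loss(p: float, y: int, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Focal loss for one example: -alpha_t * (1 - p_t)^gamma * log(p_t).

    p is the predicted probability of the positive class and y is 0 or 1.
    """
    p_t = p if y == 1 else 1.0 - p
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


# The (1 - p_t)^gamma factor shrinks the loss of easy, confident examples
# far more than that of hard ones.
easy, hard = binary_focal_loss(0.9, 1), binary_focal_loss(0.6, 1)
print(easy < hard)  # True
```

With `gamma=0.0` the modulating factor is 1 and the loss reduces to alpha-weighted binary cross entropy.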

RegressionLossFunc

Loss function for regression tasks.

ShrinkageLossFunc

Shrinkage loss for imbalanced regression tasks.
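Shrinkage loss down-weights the many easy samples with small errors while keeping roughly the squared error for hard samples. A common form is l^2 / (1 + exp(a * (c - l))) with l = |pred - target|; the sketch below averages it over a batch. The values `a = 10.0` and `c = 0.2` are illustrative assumptions, and `shrinkage_loss` is a hypothetical helper, not GraphStorm's ShrinkageLossFunc.

```python
import math
from typing import List


def shrinkage_loss(preds: List[float], targets: List[float], a: float = 10.0, c: float = 0.2) -> float:
    """Mean shrinkage loss: l^2 / (1 + exp(a * (c - l))) with l = |pred - target|.

    a sets how sharply the weight rises; c is the error at which the weight is 1/2.
    """
    total = 0.0
    for p, t in zip(preds, targets):
        l = abs(p - t)
        total += l * l / (1.0 + math.exp(a * (c - l)))
    return total / len(preds)


print(shrinkage_loss([0.1], [0.0]) < 0.1 ** 2)  # True: small errors are shrunk
print(round(shrinkage_loss([2.0], [0.0]), 4))   # 4.0: large errors keep ~l^2
```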

LinkPredictBCELossFunc

Loss function for link prediction tasks using binary cross entropy loss.
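Binary cross entropy for link prediction treats observed edges as label 1 and sampled negative edges as label 0, applied to the sigmoid of each raw score. The sketch below uses the identities -log(sigmoid(s)) = softplus(-s) and -log(1 - sigmoid(s)) = softplus(s) for numerical stability, and simply averages all positive and negative terms together; how GraphStorm balances positive and negative terms is not asserted. `softplus` and `lp_bce_loss` are hypothetical helpers.

```python
import math
from typing import List


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def lp_bce_loss(pos_scores: List[float], neg_scores: List[float]) -> float:
    """BCE over raw edge scores: positives labeled 1, sampled negatives labeled 0."""
    terms = [softplus(-s) for s in pos_scores] + [softplus(s) for s in neg_scores]
    return sum(terms) / len(terms)


# Well-separated scores cost less than uninformative ones.
print(lp_bce_loss([3.0], [-3.0]) < lp_bce_loss([0.0], [0.0]))  # True
```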

WeightedLinkPredictBCELossFunc

Loss function for link prediction tasks using binary cross entropy loss with weights.

LinkPredictAdvBCELossFunc

Binary cross entropy loss function for link prediction tasks with adversarial loss for negative samples.
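One common way to make the negative-sample term "adversarial" is self-adversarial weighting, as introduced with RotatE: each negative's BCE term is weighted by a softmax over the negative scores, so harder (higher-scoring) negatives count more. The sketch below assumes that scheme with a `temperature` that multiplies the scores; GraphStorm's exact weighting and averaging are not asserted. With autograd frameworks these weights are usually treated as constants; this pure-Python sketch has no gradients. All names are hypothetical.

```python
import math
from typing import List


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def adversarial_weights(neg_scores: List[float], temperature: float = 1.0) -> List[float]:
    """Softmax over scaled negative scores: harder negatives get more weight."""
    scaled = [temperature * s for s in neg_scores]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    return [e / z for e in exps]


def adv_bce_loss(pos_score: float, neg_scores: List[float], temperature: float = 1.0) -> float:
    """-log(sigmoid(pos)) - sum_i w_i * log(1 - sigmoid(neg_i))."""
    weights = adversarial_weights(neg_scores, temperature)
    neg_loss = sum(w * softplus(s) for w, s in zip(weights, neg_scores))
    return softplus(-pos_score) + neg_loss


print(adversarial_weights([0.0, 1.0]))  # the higher-scoring negative gets the larger weight
```

When all negatives score equally the weights are uniform and the negative term reduces to the plain average BCE over negatives.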

WeightedLinkPredictAdvBCELossFunc

Binary cross entropy loss function for link prediction tasks with adversarial loss for negative samples and weights on positive samples.

LinkPredictContrastiveLossFunc

Contrastive loss function for link prediction tasks.
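A widely used contrastive objective for link prediction is InfoNCE: each positive edge competes against its negative edges in a softmax with temperature tau. The sketch below assumes that formulation for a single positive edge and its negatives; how GraphStorm groups negatives per positive edge and which temperature it uses are not asserted, and `contrastive_loss` is a hypothetical helper.

```python
import math
from typing import List


def contrastive_loss(pos_score: float, neg_scores: List[float], tau: float = 1.0) -> float:
    """InfoNCE-style loss for one positive edge and its negatives.

    -log(exp(s+/tau) / (exp(s+/tau) + sum_j exp(s-_j/tau))), computed stably
    as logsumexp(all scores / tau) - s+/tau.
    """
    logits = [pos_score / tau] + [s / tau for s in neg_scores]
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[0]


# log(2): the positive edge is scored no higher than its single negative.
print(contrastive_loss(0.0, [0.0]))
```

A smaller `tau` sharpens the softmax, so the loss concentrates on negatives that score close to the positive.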