graphstorm.model
GraphStorm provides a set of Graph Neural Network (GNN) modules. By combining them appropriately, users can build various GNN models for different tasks.
A GNN model in GraphStorm normally contains four components:
Input layer: an input encoder that converts input node/edge features into embeddings of the given hidden dimension. The output of the input layer becomes the input of the GNN layer, or of the decoder layer if no GNN layer is defined.
GNN layer (optional): a GNN encoder that performs the message passing computation. The outputs of the GNN layer are node embeddings that will be used by the decoder layer.
Decoder layer: a task-specific decoder that converts results from either the GNN layer or the input layer into loss values for different graph machine learning (GML) tasks, e.g., classification, regression, or link prediction.
Model optimizer: GraphStorm model classes have a built-in model optimizer, which should be initialized when the GraphStorm GNN model object is constructed.
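To make the data flow between these components concrete, here is a minimal, framework-free sketch in plain Python. All function names are hypothetical and do not correspond to GraphStorm classes: the input layer is a linear projection, the optional GNN layer is one round of mean aggregation, the decoder reads embeddings as class logits and returns a cross-entropy loss, and the model optimizer is omitted.

```python
import math

# Hypothetical toy components; the names are illustrative, not GraphStorm APIs.

def input_layer(feats, weight):
    """Input layer: project each node's raw features to the hidden dimension (h = W x)."""
    return {n: [sum(w * x for w, x in zip(row, f)) for row in weight]
            for n, f in feats.items()}

def gnn_layer(h, neighbors):
    """GNN layer: one round of message passing that averages a node with its neighbors."""
    out = {}
    for n, emb in h.items():
        msgs = [h[m] for m in neighbors.get(n, [])] + [emb]
        out[n] = [sum(vals) / len(msgs) for vals in zip(*msgs)]
    return out

def decoder_loss(h, labels):
    """Decoder: read each embedding as class logits and return the mean cross-entropy loss."""
    total = 0.0
    for n, y in labels.items():
        logits = h[n]
        log_z = math.log(sum(math.exp(z) for z in logits))
        total += log_z - logits[y]
    return total / len(labels)

def model_loss(feats, weight, neighbors, labels, use_gnn=True):
    h = input_layer(feats, weight)
    if use_gnn:  # optional: without it, the decoder consumes the input embeddings directly
        h = gnn_layer(h, neighbors)
    return decoder_loss(h, labels)

feats = {0: [1.0, 0.0], 1: [0.0, 1.0]}   # two nodes with two raw features each
weight = [[1.0, 0.0], [0.0, 1.0]]        # identity projection, hidden dimension 2
neighbors = {0: [1], 1: [0]}             # one undirected edge between the nodes
labels = {0: 0, 1: 1}
loss_without_gnn = model_loss(feats, weight, neighbors, labels, use_gnn=False)
loss_with_gnn = model_loss(feats, weight, neighbors, labels)
```

In this toy graph the GNN layer averages the two nodes into identical embeddings, so the decoder can no longer separate them and the loss rises from about 0.31 to log 2 (about 0.69).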
If users would like to implement their own GNN models, a suggested practice is to extend a base GNN model class and its corresponding interface, e.g., GSgnnNodeModelBase and GSgnnNodeModelInterface, and implement the required abstract methods.
If users just want to build their own message passing methods, a suggested practice is to create their own GNN encoders by extending the GraphConvEncoder base class and implementing the forward(self, blocks, h) function, which will be called by GraphStorm GNN model classes within their own forward() function.
For examples of how to use these GraphStorm APIs to form training/inference pipelines, to switch between different GNN encoders to implement various GNN models, and to build a customized GNN encoder, please refer to GraphStorm API Programming Examples.
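The shape of such an encoder can be illustrated without DGL or PyTorch. In the hypothetical sketch below, MeanAggEncoder only mimics the forward(self, blocks, h) contract: each element of blocks is assumed to be a plain dict mapping a destination node to its source nodes for one layer (standing in for a DGL block), and h maps node IDs to embedding lists. A real encoder would subclass GraphConvEncoder and operate on tensors; that base class is not imported here.

```python
class MeanAggEncoder:
    """Hypothetical stand-in for a GraphConvEncoder subclass (no DGL/PyTorch needed).

    blocks: one dict per layer, mapping each destination node to its source nodes.
    h: dict mapping node IDs to embedding lists.
    """

    def __init__(self, num_layers):
        self.num_layers = num_layers

    def forward(self, blocks, h):
        if len(blocks) != self.num_layers:
            raise ValueError("expected one block per layer")
        for block in blocks:
            new_h = {}
            for dst, srcs in block.items():
                msgs = [h[s] for s in srcs]
                # Mean of incoming messages; a node with no sources keeps its embedding.
                new_h[dst] = [sum(v) / len(msgs) for v in zip(*msgs)] if msgs else h[dst]
            h = new_h  # this layer's outputs are the next layer's inputs
        return h

encoder = MeanAggEncoder(num_layers=2)
blocks = [{0: [1, 2], 1: [0], 2: [0]},   # layer 1 updates nodes 0, 1 and 2
          {0: [1, 2]}]                   # layer 2 only updates the seed node 0
h = {0: [1.0], 1: [2.0], 2: [4.0]}
out = encoder.forward(blocks, h)         # {0: [1.0]}
```

As with sampled mini-batch blocks, each layer narrows the destination nodes toward the seed nodes, which is why the final output contains only node 0.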
Base GNN models
GraphStorm GNN model base class.
GraphStorm base model class for node prediction tasks.
The interface for GraphStorm node prediction model.
GraphStorm GNN model base class for edge prediction tasks.
The interface for GraphStorm edge prediction model.
GraphStorm GNN model base class for link-prediction tasks.
The interface for GraphStorm link prediction model.
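The split between a base class and an interface in this list follows a common Python pattern: the base class holds shared state, and the interface declares abstract methods that a concrete model must implement before it can be instantiated. The sketch below shows that pattern with the standard abc module; every class and method name in it (ToyModelBase, ToyNodeModelInterface, ToyNodeModel, forward, predict, set_decoder) is hypothetical and does not reproduce GraphStorm's actual signatures.

```python
import abc

class ToyModelBase(abc.ABC):
    """Hypothetical base class: holds state shared by every task-specific model."""

    def __init__(self):
        self.decoder = None

    def set_decoder(self, decoder):
        self.decoder = decoder

class ToyNodeModelInterface(abc.ABC):
    """Hypothetical interface: methods every node prediction model must provide."""

    @abc.abstractmethod
    def forward(self, feats, labels):
        """Return the training loss."""

    @abc.abstractmethod
    def predict(self, feats):
        """Return predictions without computing a loss."""

class ToyNodeModel(ToyModelBase, ToyNodeModelInterface):
    """Concrete model: implements both abstract methods, so it can be instantiated."""

    def predict(self, feats):
        return [self.decoder(x) for x in feats]

    def forward(self, feats, labels):
        preds = self.predict(feats)
        # Mean squared error, as a regression-style loss.
        return sum((p - y) ** 2 for p, y in zip(preds, labels)) / len(labels)

model = ToyNodeModel()
model.set_decoder(lambda x: 2.0 * x)
preds = model.predict([1.0, 2.0])              # [2.0, 4.0]
loss = model.forward([1.0, 2.0], [2.0, 5.0])   # 0.5
```

A subclass that skips either abstract method raises TypeError when instantiated, which is how the interface enforces the required methods.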
Input Layer
The node encoder input layer for heterogeneous graphs.
The node encoder input layer for all nodes in a heterogeneous graph.
The node encoder input layer with language model (LM) support for all nodes in a heterogeneous graph.
The node encoder input embedding layer that supports only language model (LM) encoding.
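For heterogeneous graphs, an input layer has to map every node type into the same hidden space. The hypothetical sketch below (ToyHeteroInputLayer, not a GraphStorm class) gives each featured node type its own projection matrix and, as an assumption about how featureless node types might be handled, a per-node embedding table. Weights are random rather than trained, and language-model encoding is not shown.

```python
import random

class ToyHeteroInputLayer:
    """Hypothetical per-node-type input encoder for a heterogeneous graph.

    feat_dims: node type -> raw feature dimension (types with features).
    num_nodes: node type -> number of nodes (used for featureless types).
    """

    def __init__(self, feat_dims, num_nodes, hidden, seed=0):
        rng = random.Random(seed)
        self.hidden = hidden
        # One (hidden x feat_dim) projection matrix per featured node type.
        self.proj = {nt: [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(hidden)]
                     for nt, d in feat_dims.items()}
        # Assumption: featureless node types get one embedding row per node.
        self.embed = {nt: [[rng.uniform(-1, 1) for _ in range(hidden)] for _ in range(n)]
                      for nt, n in num_nodes.items() if nt not in feat_dims}

    def forward(self, feats, node_ids):
        """feats[nt] must be row-aligned with node_ids[nt] for featured types."""
        out = {}
        for nt, ids in node_ids.items():
            if nt in self.proj:
                w = self.proj[nt]
                out[nt] = [[sum(wi * xi for wi, xi in zip(row, x)) for row in w]
                           for x in feats[nt]]
            else:
                out[nt] = [self.embed[nt][i] for i in ids]
        return out

layer = ToyHeteroInputLayer(feat_dims={"paper": 3},
                            num_nodes={"paper": 2, "author": 4}, hidden=5)
out = layer.forward(feats={"paper": [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]},
                    node_ids={"paper": [0, 1], "author": [2, 3]})
```

Both node types come out with the same hidden dimension (5 here), which is what lets a downstream GNN layer or decoder treat them uniformly.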
GNN Layer
Relational graph convolution layer from Modeling Relational Data with Graph Convolutional Networks.
Relational graph conv encoder.
Relational graph attention layer from Relational Graph Attention Networks.
Relational graph attention encoder.
Heterogeneous graph transformer (HGT) layer from Heterogeneous Graph Transformer.
Heterogeneous graph transformer (HGT) layer with edge features supported.
Heterogeneous graph transformer (HGT) encoder.
GraphSAGE convolutional layer from Inductive Representation Learning on Large Graphs.
GraphSAGE conv encoder.
Graph attention layer from Graph Attention Networks.
GAT conv encoder.
GATv2 convolutional layer from How Attentive are Graph Attention Networks?
GATv2 conv encoder.
Graph attention layer with edge features supported in message passing computation.
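The first entry, the relational graph convolution layer, follows the update rule from the RGCN paper: h_i' = ReLU(W_0 h_i + sum over relations r and r-neighbors j of (1 / c_{i,r}) W_r h_j). The plain-Python sketch below implements that rule with c_{i,r} = |N_i^r|, one of the normalizations the paper discusses, and omits basis decomposition and bias; it is an illustration, not GraphStorm's implementation, and the relation names "a" and "b" are hypothetical. The encoders in this list stack several such layers.

```python
def matvec(w, x):
    """Multiply matrix w (list of rows) by vector x."""
    return [sum(a * b for a, b in zip(row, x)) for row in w]

def rgcn_layer(h, edges, w_rel, w_self):
    """One relational graph convolution step without basis decomposition or bias.

    h: node ID -> embedding; edges: relation -> list of (src, dst) pairs;
    w_rel: relation -> weight matrix W_r; w_self: self-loop matrix W_0.
    """
    out = {i: matvec(w_self, x) for i, x in h.items()}  # W_0 h_i
    for rel, pairs in edges.items():
        incoming = {}
        for src, dst in pairs:
            incoming.setdefault(dst, []).append(src)
        for dst, srcs in incoming.items():
            for src in srcs:
                msg = matvec(w_rel[rel], h[src])
                # Normalize by c_{i,r} = |N_i^r|, the number of r-neighbors of dst.
                out[dst] = [o + m / len(srcs) for o, m in zip(out[dst], msg)]
    return {i: [max(0.0, v) for v in x] for i, x in out.items()}  # ReLU

h = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
edges = {"a": [(1, 0), (2, 0)], "b": [(0, 1)]}
w_rel = {"a": [[1.0, 0.0], [0.0, 1.0]], "b": [[2.0, 0.0], [0.0, 2.0]]}
w_self = [[1.0, 0.0], [0.0, 1.0]]
out = rgcn_layer(h, edges, w_rel, w_self)  # {0: [1.5, 1.0], 1: [2.0, 1.0], 2: [1.0, 1.0]}
```

Keeping a separate weight matrix per relation is what distinguishes this layer from a homogeneous graph convolution such as GraphSAGE, which shares one set of weights across all edges.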
Decoder Layer
Decoder for node classification tasks.
Decoder for node regression tasks.
Dense bi-linear decoder for edge prediction tasks.
MLP-based decoder for edge prediction tasks.
MLP-based decoder for edge prediction tasks with edge features supported.
Decoder for edge regression tasks.
Decoder for link prediction using the dot product as the score function.
Decoder for link prediction designed for contrastive loss using the dot product as the score function.
Decoder for link prediction using DistMult as the score function.
Decoder for link prediction designed for contrastive loss using DistMult as the score function.
Decoder for link prediction using RotatE as the score function.
Decoder for link prediction designed for contrastive loss using RotatE as the score function.
Link prediction decoder using RotatE as the score function, with edge weights.
Decoder for link prediction using TransE as the score function.
Decoder for link prediction designed for contrastive loss using TransE as the score function.
Link prediction decoder using TransE as the score function, with edge weights.
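The score functions named in these link prediction decoders each fit in a few lines. The sketch below follows their usual textbook forms: the dot product, DistMult's trilinear product, TransE's negative translation distance, and RotatE's negative distance after an element-wise complex rotation. The function names, the margin gamma=12.0, and the use of an L2 norm for TransE and RotatE are assumptions for illustration; GraphStorm's decoders may use different norms, margins, or normalization.

```python
import cmath
import math

def dot_score(h, t):
    """Dot product: s = <h, t>; the relation is not used."""
    return sum(a * b for a, b in zip(h, t))

def distmult_score(h, r, t):
    """DistMult: s = sum_k h_k * r_k * t_k (a diagonal bilinear form)."""
    return sum(a * b * c for a, b, c in zip(h, r, t))

def transe_score(h, r, t, gamma=12.0):
    """TransE: s = gamma - ||h + r - t||; higher when h translated by r lands near t."""
    return gamma - math.sqrt(sum((a + b - c) ** 2 for a, b, c in zip(h, r, t)))

def rotate_score(h, theta, t, gamma=12.0):
    """RotatE: h and t are complex vectors; r rotates coordinate k by angle theta_k."""
    dist_sq = sum(abs(a * cmath.exp(1j * th) - c) ** 2 for a, th, c in zip(h, theta, t))
    return gamma - math.sqrt(dist_sq)
```

For example, transe_score([1.0, 0.0], [0.0, 1.0], [1.0, 1.0]) returns 12.0, the highest score this function can produce, because h + r lands exactly on t. DistMult is symmetric in h and t, while TransE and RotatE are not, which is one reason to prefer the latter for directed relations.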
Loss Function
Loss function for classification tasks.
Focal loss function for classification tasks.
Loss function for regression tasks.
Shrinkage loss for imbalanced regression tasks.
Loss function for link prediction tasks using binary cross entropy loss.
Loss function for link prediction tasks using binary cross entropy loss with weights.
Binary cross entropy loss function for link prediction tasks with adversarial loss for negative samples.
Binary cross entropy loss function for link prediction tasks with adversarial loss for negative samples and weights on positive samples.
Contrastive loss function for link prediction tasks.
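Three of the losses listed above reduce to short formulas. The sketch below gives textbook versions of binary cross entropy over positive and negative link scores, binary focal loss (Lin et al.), and an InfoNCE-style contrastive loss. The function names and the default alpha and gamma values are assumptions, GraphStorm's own loss classes may add temperatures, weights, or reductions not shown here, and the naive sigmoid and log pairing is not numerically stable for very large scores.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def link_bce_loss(pos_scores, neg_scores):
    """Binary cross entropy: positive edges are labeled 1, negative samples 0."""
    terms = [-math.log(sigmoid(s)) for s in pos_scores]
    terms += [-math.log(1.0 - sigmoid(s)) for s in neg_scores]
    return sum(terms) / len(terms)

def binary_focal_loss(logit, label, alpha=0.25, gamma=2.0):
    """Focal loss: scales cross entropy by (1 - p_t) ** gamma to down-weight easy examples."""
    p = sigmoid(logit)
    p_t = p if label == 1 else 1.0 - p
    alpha_t = alpha if label == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

def contrastive_loss(pos_score, neg_scores):
    """InfoNCE-style loss: softmax cross entropy of one positive against its negatives."""
    logits = [pos_score] + list(neg_scores)
    log_z = math.log(sum(math.exp(s) for s in logits))
    return log_z - pos_score
```

For a positive scored 0 against two negatives also scored 0, contrastive_loss returns log 3, since the positive receives one third of the softmax mass. With gamma=0 and alpha=0.5, the focal loss reduces to half the ordinary binary cross entropy.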