GSgnnLinkPredictionModelBase
- class graphstorm.model.GSgnnLinkPredictionModelBase
Bases:
GSgnnLinkPredictionModelInterface, GSgnnModelBase
The base class for link-prediction GNN models.
To define a link prediction GNN model and train it in GraphStorm, a user's model class must inherit from this base class and implement the basic methods forward, predict, save_model, restore_model, and create_optimizer.
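The requirement to implement these methods can be pictured as an abstract base class whose abstract methods must all be overridden before a model can be instantiated. The sketch below uses Python's `abc` module to show that idea. `LinkPredictionModelContract`, `IncompleteModel` and `ToyLinkPredictionModel` are hypothetical names, not GraphStorm classes. Making all five methods abstract is a simplification of this sketch, and the open `predict` signature is an assumption.

```python
from abc import ABC, abstractmethod


class LinkPredictionModelContract(ABC):
    """Hypothetical stand-in for the methods a link prediction model provides.

    In this sketch all five methods are abstract, so a subclass that omits
    any of them cannot be instantiated.
    """

    @abstractmethod
    def forward(self, blocks, pos_graph, neg_graph, node_feats, edge_feats,
                pos_edge_feats=None, neg_edge_feats=None, input_nodes=None):
        """Compute the training loss for one mini-batch."""

    @abstractmethod
    def predict(self, *args, **kwargs):
        """Run inference (signature left open in this sketch)."""

    @abstractmethod
    def save_model(self, model_path):
        """Save dense and sparse parameters under model_path."""

    @abstractmethod
    def restore_model(self, restore_model_path, model_layer_to_load=None):
        """Load parameters previously written by save_model."""

    @abstractmethod
    def create_optimizer(self):
        """Return the optimizer (or combined optimizer) for this model."""


class IncompleteModel(LinkPredictionModelContract):
    # Hypothetical subclass that only implements forward.
    def forward(self, *args, **kwargs):
        return 0.0


class ToyLinkPredictionModel(LinkPredictionModelContract):
    # Hypothetical subclass with trivial bodies, only to show the contract.
    def forward(self, blocks, pos_graph, neg_graph, node_feats, edge_feats,
                pos_edge_feats=None, neg_edge_feats=None, input_nodes=None):
        return 0.0

    def predict(self, *args, **kwargs):
        return None

    def save_model(self, model_path):
        pass

    def restore_model(self, restore_model_path, model_layer_to_load=None):
        pass

    def create_optimizer(self):
        return None


model = ToyLinkPredictionModel()  # works: every abstract method is overridden
try:
    IncompleteModel()  # fails: four abstract methods are still missing
except TypeError as err:
    print("cannot instantiate:", err)
```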
- abstract create_optimizer()
Create the optimizer that optimizes the model.
A user who defines a model should also define the optimizer for this model. In this method, a user specifies the optimization algorithm, the learning rate, and any other hyperparameters.
A model may require multiple optimizers. For example, we should define an optimizer for sparse embeddings and an optimizer for the dense parameters of a GNN model. In this case, a user should use a GSOptimizer to combine these optimizers.
Example:
Case 1: if there is only one optimizer:
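import torch

def create_optimizer(self):
    # Define a single torch.optim.Optimizer; Adam with lr=0.001 is only an example.
    optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
    return optimizer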
Case 2: if there are both dense and sparse optimizers:
def create_optimizer(self):
    dense = [dense_opt]    # define torch.optim.Optimizer
    sparse = [sparse_opt]  # define dgl sparse Optimizer
    optimizer = GSOptimizer(dense_opts=dense, lm_opts=None, sparse_opts=sparse)
    return optimizer
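Combining optimizers lets the training loop call `zero_grad()` and `step()` once per iteration while each wrapped optimizer keeps its own learning rate and update rule. The pure-Python sketch below illustrates that idea. `ToySGD` and `CombinedOptimizer` are hypothetical stand-ins, not `torch.optim` classes or GraphStorm's `GSOptimizer`. The sketch also assumes that the combined object exposes the same `zero_grad()`/`step()` interface as a torch optimizer.

```python
class ToySGD:
    """Hypothetical minimal optimizer: plain SGD over a dict of scalar parameters."""

    def __init__(self, params, grads, lr):
        self.params = params  # name -> current value
        self.grads = grads    # name -> gradient, filled in by the caller
        self.lr = lr

    def zero_grad(self):
        for name in self.grads:
            self.grads[name] = 0.0

    def step(self):
        for name, grad in self.grads.items():
            self.params[name] -= self.lr * grad


class CombinedOptimizer:
    """Hypothetical wrapper that drives several optimizers as one.

    It mirrors the idea of wrapping dense, language-model and sparse
    optimizers in one object, but it is not GSOptimizer's implementation.
    """

    def __init__(self, dense_opts=None, lm_opts=None, sparse_opts=None):
        self.optimizers = (list(dense_opts or []) + list(lm_opts or [])
                           + list(sparse_opts or []))

    def zero_grad(self):
        for opt in self.optimizers:
            opt.zero_grad()

    def step(self):
        for opt in self.optimizers:
            opt.step()


dense_params, dense_grads = {"w": 1.0}, {"w": 0.5}
sparse_params, sparse_grads = {"emb_0": 2.0}, {"emb_0": 1.0}
optimizer = CombinedOptimizer(
    dense_opts=[ToySGD(dense_params, dense_grads, lr=0.1)],
    lm_opts=None,
    sparse_opts=[ToySGD(sparse_params, sparse_grads, lr=1.0)],
)
optimizer.step()       # each wrapped optimizer applies its own learning rate
optimizer.zero_grad()  # each wrapped optimizer clears its own gradients
print(dense_params, sparse_params)
```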
- abstract forward(blocks, pos_graph, neg_graph, node_feats, edge_feats, pos_edge_feats=None, neg_edge_feats=None, input_nodes=None)
The forward function for link prediction.
This method is used for training. It takes a mini-batch, including the graph structure, node features, and edge features, and computes the model's loss on that mini-batch.
Parameters
- blocks: list of DGLBlock
The message passing graph for computing GNN embeddings.
- pos_graph: a DGLGraph
The graph that contains the positive edges.
- neg_graph: a DGLGraph
The graph that contains the negative edges.
- node_feats: dict of Tensors
The input node features of the message passing graphs.
- edge_feats: dict of Tensors
The input edge features of the message passing graphs.
- input_nodes: dict of Tensors
The input nodes of a mini-batch.
Returns
The link prediction loss on the mini-batch.
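A common way to turn positive and negative edges into a loss is to score every edge from the node embeddings and apply binary cross-entropy. In that setup, edges in pos_graph get label 1 and edges in neg_graph get label 0. The sketch below makes several assumptions. It uses a dot-product decoder with that loss. The graphs are plain (src, dst) edge lists. The GNN output is a hypothetical embedding dict. A real model may use a different decoder or loss function.

```python
import math


def dot_score(emb, src, dst):
    """Score an edge as the dot product of its endpoint embeddings (assumed decoder)."""
    return sum(a * b for a, b in zip(emb[src], emb[dst]))


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))


def link_prediction_loss(emb, pos_edges, neg_edges):
    """Mean binary cross-entropy: positive edges have label 1, negative edges label 0."""
    pos_terms = [-log_sigmoid(dot_score(emb, s, d)) for s, d in pos_edges]
    # -log(1 - sigmoid(x)) == -log(sigmoid(-x))
    neg_terms = [-log_sigmoid(-dot_score(emb, s, d)) for s, d in neg_edges]
    terms = pos_terms + neg_terms
    return sum(terms) / len(terms)


# Hypothetical node embeddings, standing in for the GNN output computed on `blocks`.
emb = {0: [1.0, 0.0], 1: [1.0, 0.0], 2: [-1.0, 0.0]}
pos_edges = [(0, 1)]  # edges taken from pos_graph
neg_edges = [(0, 2)]  # corrupted edges taken from neg_graph
good = link_prediction_loss(emb, pos_edges, neg_edges)
bad = link_prediction_loss(emb, neg_edges, pos_edges)  # labels swapped
print(good, bad)
```

Swapping the labels raises the loss, which is the signal the optimizer uses to separate positive from negative edges.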
- restore_model(restore_model_path, model_layer_to_load=None)
Restore saved checkpoints of a GNN model.
A user who implements this method should load the parameters of the GNN model. This method does not need to load the optimizer state.
Examples
Load a model from “/tmp/checkpoints”.
# CustomGSgnnModel is a child class of GSgnnModelBase
model = CustomGSgnnModel()
# Restore model parameters from "/tmp/checkpoints"
model.restore_model("/tmp/checkpoints")
Parameters
- restore_model_path: str
The path from which the model is restored.
- model_layer_to_load: list of str
The list of model layers to load. Supported layers include ‘gnn’, ‘embed’, and ‘decoder’.
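The model_layer_to_load argument allows loading only part of a checkpoint. For example, a user can reuse a trained GNN encoder while keeping a freshly initialized decoder. The sketch below filters a checkpoint dictionary by layer name. The helper name filter_state_by_layers is hypothetical. The "<layer>.<param>" key layout and treating None as "load every layer" are also assumptions of this sketch, not GraphStorm's checkpoint format.

```python
def filter_state_by_layers(state, model_layer_to_load=None):
    """Keep only entries whose top-level prefix is in model_layer_to_load.

    Assumes a hypothetical key layout "<layer>.<param>", e.g. "gnn.conv1.weight".
    In this sketch, None means "load every layer".
    """
    supported = {"gnn", "embed", "decoder"}
    if model_layer_to_load is None:
        return dict(state)
    unknown = set(model_layer_to_load) - supported
    if unknown:
        raise ValueError(f"unsupported layers: {sorted(unknown)}")
    return {key: value for key, value in state.items()
            if key.split(".", 1)[0] in model_layer_to_load}


checkpoint = {
    "gnn.conv1.weight": [0.1, 0.2],
    "embed.proj.weight": [0.3],
    "decoder.rel_emb": [0.4, 0.5],
}
# Reuse the trained encoder layers but leave the decoder untouched.
print(filter_state_by_layers(checkpoint, ["gnn", "embed"]))
```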
- save_model(model_path)
Save the GNN model.
When saving a GNN model, both the dense parameters and the sparse parameters need to be saved.
Examples
Save a model into “/tmp/checkpoints”.
# CustomGSgnnModel is a child class of GSgnnModelBase
model = CustomGSgnnModel()
# Model parameters will be saved into "/tmp/checkpoints"
model.save_model("/tmp/checkpoints")
Parameters
- model_path: str
The path where all model parameters and optimizer states are saved.
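Dense parameters (for example GNN weights) and sparse parameters (for example learnable node embeddings) are usually updated by different optimizers, so a save routine often writes them separately. The sketch below stores them as JSON files in a temporary directory. The helper names save_checkpoint and load_checkpoint are hypothetical. So are the file names and the one-file-per-node-type layout for sparse embeddings, none of which reflect GraphStorm's on-disk format. For brevity, the sketch does not save optimizer state.

```python
import json
import os
import tempfile


def save_checkpoint(model_path, dense_params, sparse_embs):
    """Hypothetical save: dense parameters and sparse embeddings in separate files."""
    os.makedirs(model_path, exist_ok=True)
    with open(os.path.join(model_path, "dense.json"), "w") as f:
        json.dump(dense_params, f)
    # In this sketch, sparse embeddings get one file per node type.
    for ntype, table in sparse_embs.items():
        with open(os.path.join(model_path, f"sparse_{ntype}.json"), "w") as f:
            json.dump(table, f)


def load_checkpoint(model_path):
    """Read back the dense parameters and sparse embeddings written by save_checkpoint."""
    with open(os.path.join(model_path, "dense.json")) as f:
        dense = json.load(f)
    sparse = {}
    for name in os.listdir(model_path):
        if name.startswith("sparse_") and name.endswith(".json"):
            with open(os.path.join(model_path, name)) as f:
                sparse[name[len("sparse_"):-len(".json")]] = json.load(f)
    return dense, sparse


with tempfile.TemporaryDirectory() as tmp:
    dense = {"gnn.conv1.weight": [0.1, 0.2]}
    sparse = {"user": [[0.5, 0.5], [0.1, 0.9]]}
    save_checkpoint(os.path.join(tmp, "checkpoints"), dense, sparse)
    restored = load_checkpoint(os.path.join(tmp, "checkpoints"))
    print(restored == (dense, sparse))
```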