GSgnnEdgeModelBase
- class graphstorm.model.GSgnnEdgeModelBase
Bases: GSgnnEdgeModelInterface, GSgnnModelBase
The base class for edge prediction GNN models.
To define an edge prediction GNN model and train it in GraphStorm, the model class must inherit from this base class and implement a few basic methods: forward, predict, save_model, restore_model and create_optimizer.
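The contract described above can be sketched without installing GraphStorm. The snippet below is only an illustration: `EdgeModelContract` is a hypothetical stand-in that mirrors the three methods this page marks as abstract, and `MyEdgeModel` is a placeholder subclass whose method bodies are dummies. In real code the class would inherit from `graphstorm.model.GSgnnEdgeModelBase` instead.

```python
from abc import ABC, abstractmethod


class EdgeModelContract(ABC):
    """Hypothetical stand-in for GSgnnEdgeModelBase (illustration only)."""

    @abstractmethod
    def forward(self, blocks, target_edges, node_feats, edge_feats,
                target_edge_feats, labels, input_nodes=None):
        """Return the training loss for one mini-batch."""

    @abstractmethod
    def predict(self, blocks, target_edges, node_feats, edge_feats,
                target_edge_feats, input_nodes, return_proba):
        """Return predictions for the target edges."""

    @abstractmethod
    def create_optimizer(self):
        """Return the optimizer used to train the model."""


class MyEdgeModel(EdgeModelContract):
    """Placeholder subclass; every body is a dummy."""

    def forward(self, blocks, target_edges, node_feats, edge_feats,
                target_edge_feats, labels, input_nodes=None):
        return 0.0  # a real model returns the mini-batch loss

    def predict(self, blocks, target_edges, node_feats, edge_feats,
                target_edge_feats, input_nodes, return_proba):
        return []  # a real model returns a Tensor or dict of Tensor

    def create_optimizer(self):
        return None  # a real model returns an optimizer object

    def save_model(self, model_path):
        pass  # a real model writes its parameters under model_path

    def restore_model(self, restore_model_path, model_layer_to_load=None):
        pass  # a real model reads its parameters back


model = MyEdgeModel()  # succeeds because all abstract methods are defined

try:
    EdgeModelContract()  # fails: the abstract methods are not implemented
    instantiated = True
except TypeError:
    instantiated = False
```

Leaving out any of the abstract methods in the subclass would make its instantiation fail in the same way as the bare stand-in.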
- abstract create_optimizer()
Create the optimizer that optimizes the model.
A user who defines a model should also define the optimizer for this model. By using this method, a user can define the optimization algorithm, the learning rate as well as any other hyperparameters.
A model may require multiple optimizers. For example, a model with sparse embeddings needs one optimizer for the sparse embeddings and another for the dense parameters of the GNN model. In this case, a user should use a GSOptimizer to combine these optimizers.
Example:
Case 1: if there is only one optimizer:
    def create_optimizer(self):
        # define torch.optim.Optimizer
        return optimizer
Case 2: if there are both dense and sparse optimizers:
    def create_optimizer(self):
        dense = [dense_opt]  # define torch.optim.Optimizer
        sparse = [sparse_opt]  # define dgl sparse Optimizer
        optimizer = GSOptimizer(dense_opts=dense, lm_opts=None, sparse_opts=sparse)
        return optimizer
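To make the idea of combining optimizers concrete, here is a plain-Python sketch of a wrapper that forwards `zero_grad()` and `step()` to every optimizer it holds. `CombinedOptimizer` and `CountingOptimizer` are hypothetical names made up for this illustration. This is not how `GSOptimizer` is implemented; it only shows the general pattern such a wrapper might follow.

```python
class CountingOptimizer:
    """Toy optimizer (hypothetical) that only records how often it is called."""

    def __init__(self, name):
        self.name = name
        self.steps = 0
        self.resets = 0

    def zero_grad(self):
        self.resets += 1

    def step(self):
        self.steps += 1


class CombinedOptimizer:
    """Hypothetical wrapper that drives dense and sparse optimizers together."""

    def __init__(self, dense_opts=None, sparse_opts=None):
        # Treat a missing group as an empty list.
        self.dense_opts = list(dense_opts or [])
        self.sparse_opts = list(sparse_opts or [])

    def _all(self):
        return self.dense_opts + self.sparse_opts

    def zero_grad(self):
        for opt in self._all():
            opt.zero_grad()

    def step(self):
        for opt in self._all():
            opt.step()


dense = CountingOptimizer("dense")
sparse = CountingOptimizer("sparse")
optimizer = CombinedOptimizer(dense_opts=[dense], sparse_opts=[sparse])

# One training iteration: clear gradients, (compute loss and backprop), update.
optimizer.zero_grad()
optimizer.step()
```

The training loop then only deals with a single optimizer object, regardless of how many underlying optimizers the model needs.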
- abstract forward(blocks, target_edges, node_feats, edge_feats, target_edge_feats, labels, input_nodes=None)
The forward function for edge prediction.
This method is used for training. It takes a mini-batch, including the graph structure, node features, edge features and edge labels, and computes the loss of the model on that mini-batch.
Parameters
- blocks: list of DGLBlock
The message passing graph for computing GNN embeddings.
- target_edges: a DGLGraph
The graph where we store target edges to run edge classification.
- node_feats: dict of Tensors
The input node features of the message passing graphs.
- edge_feats: dict of Tensors
The input edge features of the message passing graphs.
- target_edge_feats: dict of Tensors
The edge features of target_edges.
- labels: dict of Tensor
The labels of the predicted edges.
- input_nodes: dict of Tensors
The input nodes of a mini-batch.
Returns
The loss of prediction.
- abstract predict(blocks, target_edges, node_feats, edge_feats, target_edge_feats, input_nodes, return_proba)
Make prediction on the edges.
Parameters
- blocks: list of DGLBlock
The message passing graph for computing GNN embeddings.
- target_edges: a DGLGraph
The graph where we store target edges to run edge classification.
- node_feats: dict of Tensors
The node features of the message passing graphs.
- edge_feats: dict of Tensors
The edge features of the message passing graphs.
- target_edge_feats: dict of Tensors
The edge features of target_edges.
- input_nodes: dict of Tensors
The input nodes of a mini-batch.
- return_proba: bool
Whether to return all the predicted results or only the maximum one.
Returns
- Tensor or dict of Tensor:
The prediction results. All results are returned when return_proba is True; otherwise only the maximum value is returned.
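The effect of `return_proba` can be illustrated with plain Python lists standing in for tensors. This sketch assumes a classification decoder and reads "the maximum one" as the index of the highest-scoring class. Both points are assumptions, and `postprocess` is a hypothetical helper, not a GraphStorm function.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def postprocess(edge_logits, return_proba):
    """Hypothetical helper: turn per-edge logits into predictions."""
    probs = [softmax(row) for row in edge_logits]
    if return_proba:
        return probs  # one probability per class for every edge
    # Otherwise keep only the index of the most likely class per edge.
    return [row.index(max(row)) for row in probs]


logits = [[2.0, 0.5, -1.0],   # edge 0
          [0.1, 0.2, 3.0]]    # edge 1

all_probs = postprocess(logits, return_proba=True)
top_class = postprocess(logits, return_proba=False)
```

Here `all_probs` holds a full distribution over the three classes for each edge, while `top_class` holds a single class index per edge.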
- restore_model(restore_model_path, model_layer_to_load=None)
Restore saved checkpoints of a GNN model.
A user who implements this method should load the parameters of the GNN model. This method does not need to load the optimizer state.
Examples
Load a model from “/tmp/checkpoints”.
    # CustomGSgnnModel is a child class of GSgnnModelBase
    model = CustomGSgnnModel()
    # Restore model parameters from "/tmp/checkpoints"
    model.restore_model("/tmp/checkpoints")
Parameters
- restore_model_path: str
The path where we can restore the model.
- model_layer_to_load: list of str
List of model layers to load. Supported layers include ‘gnn’, ‘embed’, ‘decoder’.
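Selective loading with `model_layer_to_load` can be pictured as filtering a checkpoint by component name. In the sketch below the checkpoint is a plain dict and `select_layers` is a hypothetical helper. The real on-disk layout of a GraphStorm checkpoint is not described on this page, and treating `None` as "load every supported layer" is an assumption.

```python
SUPPORTED_LAYERS = ("gnn", "embed", "decoder")


def select_layers(checkpoint, model_layer_to_load=None):
    """Hypothetical helper: keep only the requested parts of a checkpoint."""
    # Assumption: None means every supported layer is loaded.
    if model_layer_to_load is None:
        wanted = SUPPORTED_LAYERS
    else:
        wanted = model_layer_to_load
    unknown = [name for name in wanted if name not in SUPPORTED_LAYERS]
    if unknown:
        raise ValueError(f"Unsupported layers: {unknown}")
    return {name: checkpoint[name] for name in wanted if name in checkpoint}


# Toy checkpoint with made-up parameter values.
checkpoint = {
    "gnn": {"layer0.weight": [0.1, 0.2]},
    "embed": {"user.weight": [0.3]},
    "decoder": {"fc.weight": [0.4]},
}

everything = select_layers(checkpoint)
encoder_only = select_layers(checkpoint, ["gnn", "embed"])
```

Loading only some layers in this way is useful, for instance, when the encoder of a trained model is reused with a freshly initialized decoder.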
- save_model(model_path)
Save the GNN model.
When saving a GNN model, we need to save the dense parameters and sparse parameters.
Examples
Save a model into “/tmp/checkpoints”.
    # CustomGSgnnModel is a child class of GSgnnModelBase
    model = CustomGSgnnModel()
    # Model parameters will be saved into "/tmp/checkpoints"
    model.save_model("/tmp/checkpoints")
Parameters
- model_path: str
The path where all model parameters and optimizer states are saved.
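One way to picture saving dense and sparse parameters is to write them to separate files under `model_path`. The sketch below uses JSON and a temporary directory so that it runs anywhere. The file names `dense.json` and `sparse.json` and both helper functions are hypothetical, and a real implementation would serialize tensors rather than lists.

```python
import json
import os
import tempfile


def save_params(model_path, dense_params, sparse_params):
    """Hypothetical helper: write dense and sparse parameters separately."""
    os.makedirs(model_path, exist_ok=True)
    with open(os.path.join(model_path, "dense.json"), "w") as f:
        json.dump(dense_params, f)
    with open(os.path.join(model_path, "sparse.json"), "w") as f:
        json.dump(sparse_params, f)


def load_params(model_path):
    """Hypothetical helper: read back what save_params wrote."""
    with open(os.path.join(model_path, "dense.json")) as f:
        dense_params = json.load(f)
    with open(os.path.join(model_path, "sparse.json")) as f:
        sparse_params = json.load(f)
    return dense_params, sparse_params


dense = {"gnn.layer0.weight": [0.1, 0.2]}   # made-up values
sparse = {"user_embedding": [[0.3, 0.4]]}   # made-up values

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "checkpoints")
    save_params(path, dense, sparse)
    saved_files = sorted(os.listdir(path))
    restored_dense, restored_sparse = load_params(path)
```

Keeping the two groups in separate files lets each be written and restored independently, which matters when the sparse embeddings are much larger than the dense parameters.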