GSgnnLinkPredictionModelBase

class graphstorm.model.GSgnnLinkPredictionModelBase

Bases: GSgnnLinkPredictionModelInterface, GSgnnModelBase

The base class for link-prediction GNN models.

To define a link prediction GNN model and train it in GraphStorm, a user's model class must inherit from this base class and implement the basic methods, including forward, predict, save_model, restore_model, and create_optimizer.

abstract create_optimizer()

Create the optimizer that optimizes the model.

A user who defines a model should also define the optimizer for this model. In this method, the user specifies the optimization algorithm, the learning rate, and any other optimizer hyperparameters.

A model may require multiple optimizers. For example, a model with learnable sparse embeddings needs one optimizer for the sparse embeddings and another for the dense parameters of the GNN. In this case, a user should combine these optimizers with a GSOptimizer.
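
The idea of combining optimizers can be sketched in plain Python. CombinedOptimizer and CountingOptimizer below are hypothetical classes, not GSOptimizer's implementation: the wrapper simply forwards zero_grad() and step() to every optimizer it holds, so a training loop can treat several optimizers as one. The sketch assumes each wrapped object provides those two methods, as torch.optim.Optimizer does.

```python
class CombinedOptimizer:
    """Hypothetical wrapper that drives several optimizers as one.

    Each wrapped object only needs zero_grad() and step() methods.
    This is an illustration of the pattern, not GSOptimizer itself.
    """

    def __init__(self, dense_opts=None, sparse_opts=None):
        self.dense_opts = list(dense_opts or [])
        self.sparse_opts = list(sparse_opts or [])

    def _all(self):
        return self.dense_opts + self.sparse_opts

    def zero_grad(self):
        # Clear gradients through every wrapped optimizer.
        for opt in self._all():
            opt.zero_grad()

    def step(self):
        # Apply one update with every wrapped optimizer.
        for opt in self._all():
            opt.step()


class CountingOptimizer:
    """Hypothetical stand-in optimizer that records how often it is called."""

    def __init__(self):
        self.steps = 0
        self.zeroed = 0

    def zero_grad(self):
        self.zeroed += 1

    def step(self):
        self.steps += 1


dense_opt = CountingOptimizer()
sparse_opt = CountingOptimizer()
optimizer = CombinedOptimizer(dense_opts=[dense_opt], sparse_opts=[sparse_opt])

# One training iteration: clear gradients, (compute loss and backward), update.
optimizer.zero_grad()
optimizer.step()
print(dense_opt.steps, sparse_opt.steps)  # prints "1 1"
```

Keeping dense and sparse optimizers in separate lists mirrors the dense_opts/sparse_opts split in the GSOptimizer call shown in Case 2.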

Example:

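Case 1: the model uses a single optimizer.

import torch

def create_optimizer(self):
    # One torch.optim.Optimizer over all model parameters.
    # Adam and lr=0.001 are illustrative choices.
    return torch.optim.Adam(self.parameters(), lr=0.001)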

Case 2: the model uses both dense and sparse optimizers.

def create_optimizer(self):
    dense = [dense_opt]    # dense_opt: a torch.optim.Optimizer for the dense parameters
    sparse = [sparse_opt]  # sparse_opt: a DGL sparse optimizer for the sparse embeddings
    optimizer = GSOptimizer(dense_opts=dense,
                            lm_opts=None,
                            sparse_opts=sparse)
    return optimizer

abstract forward(blocks, pos_graph, neg_graph, node_feats, edge_feats, pos_edge_feats=None, neg_edge_feats=None, input_nodes=None)

The forward function for link prediction.

This method is used for training. It takes a mini-batch, including the graph structure, node features, and edge features, and computes the model's loss on that mini-batch.

Parameters

blocks: list of DGLBlock

The message passing graphs for computing GNN embeddings.

pos_graph: a DGLGraph

The graph that contains the positive edges.

neg_graph: a DGLGraph

The graph that contains the negative edges.

node_feats: dict of Tensors

The input node features of the message passing graphs.

edge_feats: dict of Tensors

The input edge features of the message passing graphs.

input_nodes: dict of Tensors

The input nodes of a mini-batch.

Returns

The link prediction loss on the mini-batch.
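
A minimal sketch of the kind of loss a forward implementation computes, written in plain Python on toy data. It assumes a dot-product edge scorer and a binary cross-entropy loss, with positive edges labeled 1 and negative edges labeled 0; these are illustrative choices, not GraphStorm's decoder or loss. The emb dictionary stands in for the GNN embeddings computed from blocks, and the hypothetical pos_edges and neg_edges lists stand in for the edges of pos_graph and neg_graph.

```python
import math


def dot(u, v):
    # Dot product of two embedding vectors (the assumed edge scorer).
    return sum(a * b for a, b in zip(u, v))


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def link_prediction_loss(emb, pos_edges, neg_edges):
    """Mean binary cross-entropy over positive and negative edge scores.

    emb: mapping from node ID to embedding vector (stand-in for GNN output).
    pos_edges / neg_edges: hypothetical lists of (src, dst) pairs.
    """
    # Positive edges (label 1): -log(sigmoid(s)) == softplus(-s).
    pos_terms = [softplus(-dot(emb[s], emb[d])) for s, d in pos_edges]
    # Negative edges (label 0): -log(1 - sigmoid(s)) == softplus(s).
    neg_terms = [softplus(dot(emb[s], emb[d])) for s, d in neg_edges]
    terms = pos_terms + neg_terms
    return sum(terms) / len(terms)


emb = {0: [1.0, 0.0], 1: [1.0, 0.0], 2: [-1.0, 0.0]}
loss = link_prediction_loss(emb, pos_edges=[(0, 1)], neg_edges=[(0, 2)])
print(round(loss, 4))  # prints 0.3133
```

Averaging over all scored edges keeps the loss scale independent of the mini-batch size; a real implementation would express the same computation with tensor operations so it can be back-propagated.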

restore_model(restore_model_path, model_layer_to_load=None)

Restore saved checkpoints of a GNN model.

A user who implements this method should load the parameters of the GNN model. This method does not need to load the optimizer state.

Examples

Load a model from “/tmp/checkpoints”.

# CustomGSgnnModel is a child class of GSgnnModelBase
model = CustomGSgnnModel()

# Restore model parameters from "/tmp/checkpoints"
model.restore_model("/tmp/checkpoints")

Parameters

restore_model_path: str

The path from which the model is restored.

model_layer_to_load: list of str

The list of model layers to load. Supported layers include ‘gnn’, ‘embed’, and ‘decoder’.
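
One way to think about model_layer_to_load is as a filter over the saved checkpoint. The sketch below assumes, hypothetically, that a checkpoint is a dict keyed by layer name; GraphStorm's on-disk layout is not described here, select_layers is not a GraphStorm function, and treating None as "load every layer" is an assumption.

```python
# Layer names listed as supported in the model_layer_to_load description.
SUPPORTED_LAYERS = ("gnn", "embed", "decoder")


def select_layers(checkpoint, model_layer_to_load=None):
    """Return the parts of a checkpoint that should be restored.

    checkpoint: hypothetical dict mapping a layer name to that layer's
    saved parameters.
    model_layer_to_load: list of layer names. Treating None as "load
    every layer" is an assumption made for this sketch.
    """
    if model_layer_to_load is None:
        layers = SUPPORTED_LAYERS
    else:
        layers = model_layer_to_load
    unknown = sorted(set(layers) - set(SUPPORTED_LAYERS))
    if unknown:
        raise ValueError(f"unsupported layers: {unknown}")
    # Keep requested layers that actually exist in the checkpoint.
    return {name: checkpoint[name] for name in layers if name in checkpoint}


checkpoint = {
    "gnn": {"conv1.weight": [0.1, 0.2]},
    "embed": {"user": [[0.3, 0.4]]},
    "decoder": {"weight": [0.5]},
}
# Restore only the GNN and decoder layers.
partial = select_layers(checkpoint, ["gnn", "decoder"])
print(sorted(partial))  # prints ['decoder', 'gnn']
```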

save_model(model_path)

Save the GNN model.

When saving a GNN model, both the dense parameters and the sparse parameters (such as learnable sparse embeddings) must be saved.

Examples

Save a model into “/tmp/checkpoints”.

# CustomGSgnnModel is a child class of GSgnnModelBase
model = CustomGSgnnModel()

# Model parameters will be saved into "/tmp/checkpoints"
model.save_model("/tmp/checkpoints")

Parameters

model_path: str

The path where all model parameters and optimizer states are saved.
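
A minimal sketch of writing the dense and sparse parameters to separate files under model_path, using pickle on plain Python containers. The function save_model_sketch, its arguments, and the file names are hypothetical and do not reflect GraphStorm's checkpoint format; a real model would save tensors (for example with torch.save) rather than lists.

```python
import os
import pickle
import tempfile


def save_model_sketch(model_path, dense_params, sparse_embs):
    """Write dense parameters and sparse embeddings under model_path.

    dense_params: dict of parameter name -> values (stand-in for a state_dict).
    sparse_embs: dict of node type -> embedding table (stand-in for sparse
    embeddings). File names are hypothetical.
    """
    os.makedirs(model_path, exist_ok=True)
    # All dense parameters go into a single file.
    with open(os.path.join(model_path, "dense_params.pkl"), "wb") as f:
        pickle.dump(dense_params, f)
    # Each sparse embedding table is written to its own file.
    for ntype, table in sparse_embs.items():
        with open(os.path.join(model_path, f"sparse_emb_{ntype}.pkl"), "wb") as f:
            pickle.dump(table, f)


path = tempfile.mkdtemp()
save_model_sketch(
    path,
    dense_params={"gnn.weight": [0.5, -0.5]},
    sparse_embs={"user": [[0.1, 0.2]], "item": [[0.3, 0.4]]},
)
print(sorted(os.listdir(path)))
# prints ['dense_params.pkl', 'sparse_emb_item.pkl', 'sparse_emb_user.pkl']
```

Writing one file per embedding table is just one plausible layout; it lets a restore step read only the tables it needs.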