GSgnnEdgeModelBase

class graphstorm.model.GSgnnEdgeModelBase

Bases: GSgnnEdgeModelInterface, GSgnnModelBase

The base class for edge-prediction GNN models.

To define an edge prediction GNN model and train it in GraphStorm, the model class must inherit from this base class. A user needs to implement some basic methods, including forward, predict, save_model, restore_model and create_optimizer.
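The required-method contract can be sketched with a stand-in abstract class. `EdgeModelSketchBase`, `IncompleteModel` and `missing_methods` below are hypothetical names that only mirror the method names listed above; a real model would inherit from `GSgnnEdgeModelBase` instead. The stand-in marks all five methods as abstract for illustration, whereas this reference marks only `create_optimizer`, `forward` and `predict` as abstract.

```python
from abc import ABC, abstractmethod


class EdgeModelSketchBase(ABC):
    """Hypothetical stand-in mirroring the method names of GSgnnEdgeModelBase."""

    @abstractmethod
    def forward(self, blocks, target_edges, node_feats, edge_feats,
                target_edge_feats, labels, input_nodes=None):
        """Compute the loss on a mini-batch (used for training)."""

    @abstractmethod
    def predict(self, blocks, target_edges, node_feats, edge_feats,
                target_edge_feats, input_nodes, return_proba):
        """Make predictions on the target edges."""

    @abstractmethod
    def create_optimizer(self):
        """Create the optimizer that optimizes the model."""

    @abstractmethod
    def save_model(self, model_path):
        """Save the model parameters."""

    @abstractmethod
    def restore_model(self, restore_model_path, model_layer_to_load=None):
        """Restore the model parameters."""


class IncompleteModel(EdgeModelSketchBase):
    """Implements only forward, so Python refuses to instantiate it."""

    def forward(self, blocks, target_edges, node_feats, edge_feats,
                target_edge_feats, labels, input_nodes=None):
        return 0.0


def missing_methods(cls):
    """Return the sorted names of abstract methods a class has not implemented."""
    return sorted(cls.__abstractmethods__)
```

Calling `missing_methods(IncompleteModel)` lists what is still left to implement, and instantiating `IncompleteModel` raises `TypeError` until every abstract method is defined.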

abstract create_optimizer()

Create the optimizer that optimizes the model.

A user who defines a model should also define the optimizer for this model. By using this method, a user can define the optimization algorithm, the learning rate as well as any other hyperparameters.

A model may require multiple optimizers. For example, we should define an optimizer for sparse embeddings and an optimizer for the dense parameters of a GNN model. In this case, a user should use a GSOptimizer to combine these optimizers.

Example:

Case 1: if there is only one optimizer:

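import torch

def create_optimizer(self):
    # e.g. Adam over the dense parameters; the learning rate is illustrative
    optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
    return optimizer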

Case 2: if there are both dense and sparse optimizers:

def create_optimizer(self):
    dense = [dense_opt] # define torch.optim.Optimizer
    sparse = [sparse_opt] # define dgl sparse Optimizer
    optimizer = GSOptimizer(dense_opts=dense,
                            lm_opts=None,
                            sparse_opts=sparse)
    return optimizer
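What "combining" several optimizers means can be illustrated with a small stand-in wrapper. This is a sketch only: `CombinedOptimizerSketch` and `RecordingOptimizer` are hypothetical classes, not GraphStorm's GSOptimizer, and the sketch assumes each wrapped optimizer exposes `zero_grad()` and `step()`.

```python
class CombinedOptimizerSketch:
    """Hypothetical wrapper; it is NOT GraphStorm's GSOptimizer."""

    def __init__(self, dense_opts=None, lm_opts=None, sparse_opts=None):
        # Treat a missing group as an empty list so iteration stays simple.
        self.groups = {
            "dense": list(dense_opts or []),
            "lm": list(lm_opts or []),
            "sparse": list(sparse_opts or []),
        }
        if not any(self.groups.values()):
            raise ValueError("at least one optimizer is required")

    def _all(self):
        # Flatten the groups in a fixed order: dense, lm, sparse.
        return [opt for group in self.groups.values() for opt in group]

    def zero_grad(self):
        for opt in self._all():
            opt.zero_grad()

    def step(self):
        for opt in self._all():
            opt.step()


class RecordingOptimizer:
    """Toy optimizer that records calls, standing in for a real optimizer."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def zero_grad(self):
        self.log.append((self.name, "zero_grad"))

    def step(self):
        self.log.append((self.name, "step"))


log = []
optimizer = CombinedOptimizerSketch(
    dense_opts=[RecordingOptimizer("dense", log)],
    sparse_opts=[RecordingOptimizer("sparse", log)],
)
optimizer.zero_grad()
optimizer.step()
```

The training loop then calls `zero_grad()` and `step()` once on the combined object, and every wrapped optimizer is updated in turn.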
abstract forward(blocks, target_edges, node_feats, edge_feats, target_edge_feats, labels, input_nodes=None)

The forward function for edge prediction.

This method is used for training. It takes a mini-batch, including the graph structure, node features, edge features and edge labels and computes the loss of the model in the mini-batch.

Parameters

blocks: list of DGLBlock

The message passing graph for computing GNN embeddings.

target_edges: DGLGraph

The graph where we store target edges to run edge classification.

node_feats: dict of Tensors

The input node features of the message passing graphs.

edge_feats: dict of Tensors

The input edge features of the message passing graphs.

target_edge_feats: dict of Tensors

The edge features of target_edges.

labels: dict of Tensors

The labels of the predicted edges.

input_nodes: dict of Tensors

The input nodes of a mini-batch.

Returns

The loss of prediction.
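As a rough illustration of "the loss of the model in the mini-batch", the sketch below computes a mean softmax cross-entropy over target edges in plain Python. It assumes the GNN and decoder have already produced per-edge class scores and that cross-entropy is the chosen loss; both are assumptions, and `mean_cross_entropy` is a hypothetical helper, not part of GraphStorm.

```python
import math


def mean_cross_entropy(logits, labels):
    """Mean softmax cross-entropy over the target edges of one mini-batch.

    logits: list of per-edge score lists (one score per class).
    labels: list of integer class ids, one per target edge.
    """
    if not labels or len(logits) != len(labels):
        raise ValueError("need exactly one label per target edge")
    total = 0.0
    for scores, label in zip(logits, labels):
        # Subtract the max score for numerical stability (log-sum-exp trick).
        peak = max(scores)
        log_norm = peak + math.log(sum(math.exp(s - peak) for s in scores))
        total += log_norm - scores[label]
    return total / len(labels)


# Two target edges, two classes, uniform scores for each edge.
loss = mean_cross_entropy([[0.0, 0.0], [1.0, 1.0]], [0, 1])
```

With uniform scores over two classes, each edge contributes log(2), so `loss` equals `math.log(2)` up to floating-point error.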

abstract predict(blocks, target_edges, node_feats, edge_feats, target_edge_feats, input_nodes, return_proba)

Make prediction on the edges.

Parameters

blocks: list of DGLBlock

The message passing graph for computing GNN embeddings.

target_edges: DGLGraph

The graph where we store target edges to run edge classification.

node_feats: dict of Tensors

The node features of the message passing graphs.

edge_feats: dict of Tensors

The edge features of the message passing graphs.

target_edge_feats: dict of Tensors

The edge features of target_edges.

input_nodes: dict of Tensors

The input nodes of a mini-batch.

return_proba: bool

Whether to return all the predicted results or only the maximum one.

Returns

Tensor or dict of Tensor:

The prediction results. When return_proba is True, all the results are returned; otherwise only the maximum value is returned.
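The effect of return_proba can be sketched in plain Python for a classification decoder. This assumes that "all the results" means the per-class probabilities and that "the maximum" means the id of the highest-scoring class; `predict_from_logits` is a hypothetical helper used only for illustration.

```python
import math


def predict_from_logits(logits, return_proba):
    """Turn per-edge class scores into probabilities or predicted class ids."""
    results = []
    for scores in logits:
        # Numerically stable softmax over the class scores of one edge.
        peak = max(scores)
        exps = [math.exp(s - peak) for s in scores]
        norm = sum(exps)
        probs = [e / norm for e in exps]
        if return_proba:
            results.append(probs)
        else:
            # Index of the highest-probability class.
            results.append(max(range(len(probs)), key=probs.__getitem__))
    return results


logits = [[2.0, 0.5, 0.1], [0.0, 3.0, 1.0]]
all_probs = predict_from_logits(logits, return_proba=True)
top_class = predict_from_logits(logits, return_proba=False)
```

Here `all_probs` holds one probability list per edge (each summing to 1), while `top_class` holds one class id per edge.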

restore_model(restore_model_path, model_layer_to_load=None)

Restore saved checkpoints of a GNN model.

A user who implements this method should load the parameters of the GNN model. This method does not need to load the optimizer state.

Examples

Load a model from “/tmp/checkpoints”.

# CustomGSgnnModel is a child class of GSgnnModelBase
model = CustomGSgnnModel()

# Restore model parameters from "/tmp/checkpoints"
model.restore_model("/tmp/checkpoints")

Parameters

restore_model_path: str

The path where we can restore the model.

model_layer_to_load: list of str

List of model layers to load. Supported layers include ‘gnn’, ‘embed’ and ‘decoder’.
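One way an implementation might honor model_layer_to_load is to filter the checkpoint by layer name before loading. The sketch below is hypothetical: the checkpoint is modeled as a plain dict keyed by layer name, `select_layers` is an illustrative helper, and treating `None` as "load every supported layer" is an assumption rather than documented behavior.

```python
SUPPORTED_LAYERS = ("gnn", "embed", "decoder")


def select_layers(checkpoint, model_layer_to_load=None):
    """Pick the checkpoint entries to restore.

    checkpoint: dict mapping a layer name to its saved parameters.
    model_layer_to_load: list of layer names, or None for all supported layers.
    """
    wanted = SUPPORTED_LAYERS if model_layer_to_load is None else model_layer_to_load
    unknown = [name for name in wanted if name not in SUPPORTED_LAYERS]
    if unknown:
        raise ValueError(f"unsupported layers: {unknown}")
    # Skip layers that the checkpoint does not contain.
    return {name: checkpoint[name] for name in wanted if name in checkpoint}


checkpoint = {"gnn": {"w": 1.0}, "embed": {"w": 2.0}, "decoder": {"w": 3.0}}
encoder_only = select_layers(checkpoint, ["gnn", "embed"])
```

Restoring only ‘gnn’ and ‘embed’ in this way leaves the decoder untouched, which is a common pattern when reusing a trained encoder with a new decoder.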

save_model(model_path)

Save the GNN model.

When saving a GNN model, we need to save the dense parameters and sparse parameters.

Examples

Save a model into “/tmp/checkpoints”.

# CustomGSgnnModel is a child class of GSgnnModelBase
model = CustomGSgnnModel()

# Model parameters will be saved into "/tmp/checkpoints"
model.save_model("/tmp/checkpoints")

Parameters

model_path: str

The path where all model parameters and optimizer states are saved.
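A save and restore round trip can be sketched with a stand-in model that keeps its dense and sparse parameters in separate files. `TinyModelSketch` and the file names `dense.json` and `sparse.json` are hypothetical; a real implementation would typically serialize tensors with its framework's own tooling rather than JSON, and this sketch handles parameters only, not optimizer state.

```python
import json
import os
import tempfile


class TinyModelSketch:
    """Hypothetical stand-in model with dense and sparse parameters."""

    def __init__(self):
        self.dense = {"weight": [0.0, 0.0]}
        self.sparse = {"node_emb": [[0.0], [0.0]]}

    def save_model(self, model_path):
        # Write dense and sparse parameters to separate files in one directory.
        os.makedirs(model_path, exist_ok=True)
        with open(os.path.join(model_path, "dense.json"), "w") as f:
            json.dump(self.dense, f)
        with open(os.path.join(model_path, "sparse.json"), "w") as f:
            json.dump(self.sparse, f)

    def restore_model(self, restore_model_path):
        # Load parameters only; optimizer state is not handled here.
        with open(os.path.join(restore_model_path, "dense.json")) as f:
            self.dense = json.load(f)
        with open(os.path.join(restore_model_path, "sparse.json")) as f:
            self.sparse = json.load(f)


with tempfile.TemporaryDirectory() as tmp_dir:
    trained = TinyModelSketch()
    trained.dense["weight"] = [0.5, -1.5]
    trained.save_model(tmp_dir)

    restored = TinyModelSketch()
    restored.restore_model(tmp_dir)
```

After the round trip, `restored` carries the same dense and sparse parameters as `trained`.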