GSgnnLinkPredictionModelInterface

class graphstorm.model.GSgnnLinkPredictionModelInterface

Bases: object

The interface for GraphStorm link prediction models.

This interface defines one method, forward(), which is used for training. Link prediction models should inherit this interface and implement this method.
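The following is a minimal sketch of the inheritance contract. It does not import GraphStorm. Instead it uses a hypothetical stand-in class, LinkPredictionModelInterface, built on Python's abc module, so that a missing forward() is caught when a subclass is instantiated. The documented class lists object as its base, so it is an assumption that GraphStorm enforces the contract in the same way. IncompleteModel, ConstantLossModel, and can_instantiate are illustrative names only.

```python
import abc


class LinkPredictionModelInterface(abc.ABC):
    """Hypothetical stand-in mirroring the shape of
    graphstorm.model.GSgnnLinkPredictionModelInterface (not the real class)."""

    @abc.abstractmethod
    def forward(self, blocks, pos_graph, neg_graph, node_feats, edge_feats,
                pos_edge_feats=None, neg_edge_feats=None, input_nodes=None):
        """Compute and return the loss of one mini-batch."""


class IncompleteModel(LinkPredictionModelInterface):
    """Does not implement forward(), so abc refuses to instantiate it."""


class ConstantLossModel(LinkPredictionModelInterface):
    """Trivial subclass that satisfies the interface."""

    def forward(self, blocks, pos_graph, neg_graph, node_feats, edge_feats,
                pos_edge_feats=None, neg_edge_feats=None, input_nodes=None):
        # A real model would run message passing over `blocks` and score the
        # edges of pos_graph / neg_graph; this placeholder returns a fixed loss.
        return 0.0


def can_instantiate(cls):
    """Return True if cls() succeeds, False if abc raises TypeError."""
    try:
        cls()
        return True
    except TypeError:
        return False


print(can_instantiate(IncompleteModel))    # False
print(can_instantiate(ConstantLossModel))  # True
```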

abstract forward(blocks, pos_graph, neg_graph, node_feats, edge_feats, pos_edge_feats=None, neg_edge_feats=None, input_nodes=None)

The forward function for link prediction.

This method is used for training. It takes a mini-batch as input, consisting of a list of DGL message flow graphs (MFGs), the graphs of positive and negative edges, and the node and edge features. It returns the model's loss on that mini-batch. More detailed information about DGL MFGs can be found in DGL Neighbor Sampling Overview.
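This section does not specify which loss forward() computes. The sketch below is illustrative only. It assumes a dot-product edge scorer and binary cross-entropy, a common choice for link prediction, where positive edges are pushed toward label 1 and negative edges toward label 0. The sketch skips message passing over blocks entirely and uses plain Python lists and dicts in place of DGL graphs and tensors. The names toy_link_prediction_loss, pos_edges, neg_edges, and node_emb are hypothetical.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def toy_link_prediction_loss(pos_edges, neg_edges, node_emb):
    """Mean binary cross-entropy over positive and negative edges.

    pos_edges / neg_edges: lists of (src, dst) node ids, standing in for the
    edges of pos_graph / neg_graph.
    node_emb: dict of node id -> list of floats, standing in for the node
    representations a GNN would compute from `blocks` and `node_feats`.
    """
    def score(u, v):
        # Dot-product decoder: a higher score means the edge is more likely.
        return sum(a * b for a, b in zip(node_emb[u], node_emb[v]))

    # Label 1 for positive edges: -log(sigmoid(s)).
    pos_loss = -sum(log_sigmoid(score(u, v)) for u, v in pos_edges)
    # Label 0 for negative edges: -log(1 - sigmoid(s)) == -log(sigmoid(-s)).
    neg_loss = -sum(log_sigmoid(-score(u, v)) for u, v in neg_edges)
    return (pos_loss + neg_loss) / (len(pos_edges) + len(neg_edges))


# Nodes 0 and 1 point the same way; node 2 points the opposite way.
emb = {0: [1.0, 0.0], 1: [1.0, 0.0], 2: [-1.0, 0.0]}
good = toy_link_prediction_loss([(0, 1)], [(0, 2)], emb)
bad = toy_link_prediction_loss([(0, 2)], [(0, 1)], emb)
print(good < bad)  # True
```

Treating (0, 1) as positive and (0, 2) as negative yields a lower loss than the reverse assignment. Training follows this signal. The toy returns a plain Python float.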

Parameters

blocks: list of DGL MFGs

Sampled subgraphs as a list of DGL message flow graphs (MFGs). More detailed information about DGL MFGs can be found in DGL Neighbor Sampling Overview.

pos_graph: a DGLGraph

The graph that contains the positive edges.

neg_graph: a DGLGraph

The graph that contains the negative edges.

node_feats: dict of Tensors

The input node features of the message passing graph.

edge_feats: dict of Tensors

The input edge features of the message passing graph.

input_nodes: dict of Tensors

The input nodes of a mini-batch.
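The parameters above describe one mini-batch. The sketch below shows one way they can relate, under two simplifying assumptions. First, negative edges are produced by corrupting the destination of each positive edge uniformly at random. This is a common negative-sampling scheme, but it is not necessarily the one GraphStorm uses. Second, the input nodes are taken to be every node within num_layers hops of the edge endpoints. This is the full neighborhood, without the fan-out sampling that DGL's neighbor samplers apply. Graphs are plain Python lists and dicts, and uniform_negative_edges, compute_input_nodes, and in_neighbors are hypothetical names.

```python
import random


def uniform_negative_edges(pos_edges, num_nodes, k, rng):
    """Corrupt each positive edge (u, v) into k negatives (u, v'), v' uniform."""
    return [(u, rng.randrange(num_nodes)) for u, _ in pos_edges for _ in range(k)]


def compute_input_nodes(seed_nodes, in_neighbors, num_layers):
    """All nodes within num_layers in-hops of seed_nodes (breadth-first search).

    With full neighborhoods, these are the nodes whose input features are
    needed to compute num_layers-hop representations of the seed nodes.
    """
    needed = set(seed_nodes)
    frontier = set(seed_nodes)
    for _ in range(num_layers):
        reached = set()
        for v in frontier:
            reached.update(in_neighbors.get(v, ()))
        frontier = reached - needed  # expand only newly reached nodes
        needed |= reached
    return needed


# Toy graph with edges 1->0, 2->1, 3->2, stored as in-neighbor lists.
in_neighbors = {0: [1], 1: [2], 2: [3], 3: []}
pos_edges = [(1, 0)]
neg_edges = uniform_negative_edges(pos_edges, num_nodes=4, k=2, rng=random.Random(0))
# Seed nodes are the endpoints of both positive and negative edges.
seeds = {n for edge in pos_edges + neg_edges for n in edge}
input_nodes = compute_input_nodes(seeds, in_neighbors, num_layers=1)
print(sorted(input_nodes))
```

Uniform corruption can occasionally produce a true edge or a self-loop. Practical samplers may filter these out.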

Returns

float: The prediction loss of this mini-batch.