GSgnnLinkPredictionModelInterface
- class graphstorm.model.GSgnnLinkPredictionModelInterface
Bases: object
The interface for GraphStorm link prediction models.
This interface defines one method, forward(), which is used for training. Link prediction models should inherit this interface and implement this method.
- abstract forward(blocks, pos_graph, neg_graph, node_feats, edge_feats, pos_edge_feats=None, neg_edge_feats=None, input_nodes=None)
The forward function for link prediction.
This method is used for training. It takes as input a list of DGL message flow graphs (MFGs), the graphs of positive and negative edges, and the node and edge features of a mini-batch, and returns the model's loss on that mini-batch. More detailed information about DGL MFGs can be found in the DGL Neighbor Sampling Overview.
Parameters
- blocks: list of DGL MFGs
Sampled subgraphs, given as a list of DGL message flow graphs (MFGs). More detailed information about DGL MFGs can be found in the DGL Neighbor Sampling Overview.
- pos_graph: a DGLGraph
The graph that contains the positive edges.
- neg_graph: a DGLGraph
The graph that contains the negative edges.
- node_feats: dict of Tensors
The input node features of the message passing graph.
- edge_feats: dict of Tensors
The input edge features of the message passing graph.
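- pos_edge_feats: optional (default None)
The input edge features of the positive edges.
- neg_edge_feats: optional (default None)
The input edge features of the negative edges.
- input_nodes: dict of Tensors
The input nodes of a mini-batch.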
Returns
float: The prediction loss of this mini-batch.
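A minimal sketch of a model that satisfies this interface. Real GraphStorm models depend on DGL and PyTorch, so the sketch redefines a hypothetical stand-in base class, LinkPredictionModelInterface, and replaces DGLGraphs and Tensors with plain Python containers: pos_graph and neg_graph are assumed to be dicts mapping a canonical edge type (src_type, relation, dst_type) to a list of (src_id, dst_id) pairs, and node_feats is assumed to map a node type to a list of feature vectors. The toy DotProductLinkPredictor (also hypothetical) skips message passing over blocks, ignores edge features, scores each edge with a dot product, and returns the mean binary cross-entropy loss, with positive edges labeled 1 and negative edges labeled 0. It illustrates the shape of the contract (mini-batch inputs in, scalar loss out), not GraphStorm's own implementation.

```python
import math


class LinkPredictionModelInterface:
    """Hypothetical stand-in for graphstorm.model.GSgnnLinkPredictionModelInterface.

    Like the documented class it derives from object and declares forward()
    without implementing it; it is redefined here so the sketch runs without
    GraphStorm, DGL, or PyTorch installed.
    """

    def forward(self, blocks, pos_graph, neg_graph, node_feats, edge_feats,
                pos_edge_feats=None, neg_edge_feats=None, input_nodes=None):
        raise NotImplementedError


def bce_with_logits(score, label):
    """Numerically stable binary cross-entropy on a raw score (logit)."""
    return max(score, 0.0) - score * label + math.log1p(math.exp(-abs(score)))


class DotProductLinkPredictor(LinkPredictionModelInterface):
    """Hypothetical toy model: dot-product edge scores plus BCE loss.

    It does not run message passing over `blocks` and ignores all edge
    features; a real model would compute node embeddings first.
    """

    def forward(self, blocks, pos_graph, neg_graph, node_feats, edge_feats,
                pos_edge_feats=None, neg_edge_feats=None, input_nodes=None):
        losses = []
        # Positive edges are labeled 1.0, negative edges 0.0.
        for graph, label in ((pos_graph, 1.0), (neg_graph, 0.0)):
            for (src_type, _, dst_type), edges in graph.items():
                for u, v in edges:
                    score = sum(a * b for a, b in zip(node_feats[src_type][u],
                                                      node_feats[dst_type][v]))
                    losses.append(bce_with_logits(score, label))
        # Mean loss over every positive and negative edge in the mini-batch.
        return sum(losses) / len(losses)


# Usage with a tiny two-user, two-item mini-batch.
feats = {"user": [[1.0, 0.0], [0.0, 1.0]], "item": [[1.0, 0.0], [0.0, 1.0]]}
pos = {("user", "buys", "item"): [(0, 0), (1, 1)]}  # each scores 1.0
neg = {("user", "buys", "item"): [(0, 1), (1, 0)]}  # each scores 0.0
model = DotProductLinkPredictor()
print(round(model.forward([], pos, neg, feats, {}), 4))  # prints 0.5032
```

Swapping the positive and negative graphs raises the loss, because the model then scores the false edges higher than the true ones.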