GSgnnEdgeModelInterface
- class graphstorm.model.GSgnnEdgeModelInterface
Bases: object
The interface for GraphStorm edge prediction models.
This interface defines two main methods: forward() for training and predict() for inference. Edge GNN models should inherit this interface and implement both methods.
- abstract forward(blocks, target_edges, node_feats, edge_feats, target_edge_feats, labels, input_nodes=None)
The forward function for edge prediction.
This method is used for training. It takes a list of DGL message flow graphs (MFGs), node features, edge features, and edge labels of a mini-batch as inputs, and returns the loss of the model on that mini-batch. More detailed information about DGL MFGs can be found in DGL Neighbor Sampling Overview.
Parameters
- blocks: list of DGL MFGs
Sampled subgraph in the list of DGL message flow graph (MFG) format. More detailed information about DGL MFG can be found in DGL Neighbor Sampling Overview.
- target_edges: a DGLGraph
The graph that stores the target edges on which to run edge prediction.
- node_feats: dict of Tensors
The input node features of the message passing graph.
- edge_feats: dict of Tensors
The input edge features of the message passing graph.
- target_edge_feats: dict of Tensors
The edge features of the target edges.
- labels: dict of Tensors
The labels of the target edges.
- input_nodes: dict of Tensors
The input nodes of the mini-batch.
Returns
float: The prediction loss of this mini-batch.
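Although labels is a dict, forward() returns a single loss for the whole mini-batch. The sketch below shows one way per-edge-type predictions could be reduced to that value. It is not GraphStorm code: it assumes an edge classification task, that the dict keys are DGL canonical edge types, and that the loss is cross-entropy averaged over all target edges. Plain Python lists stand in for Tensors, and `cross_entropy` and `minibatch_loss` are hypothetical helper names.

```python
import math
from typing import Dict, List, Tuple

# Assumption: keys are DGL canonical edge types, (src_type, relation, dst_type).
EdgeType = Tuple[str, str, str]


def cross_entropy(logits: List[float], label: int) -> float:
    """Cross-entropy of one edge's class logits against its integer label."""
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]


def minibatch_loss(logits: Dict[EdgeType, List[List[float]]],
                   labels: Dict[EdgeType, List[int]]) -> float:
    """Hypothetical helper: average the per-edge loss over every target edge
    of every edge type, giving the single float a forward() would return."""
    total, count = 0.0, 0
    for etype, etype_labels in labels.items():
        for edge_logits, label in zip(logits[etype], etype_labels):
            total += cross_entropy(edge_logits, label)
            count += 1
    return total / count
```

For example, two `("user", "rates", "movie")` edges with logits `[2.0, 0.0]` and `[0.0, 2.0]` and labels `0` and `1` each contribute log(1 + e^-2), so their mean is about 0.127. Averaging over all edges weights each edge type by its number of target edges; averaging within each edge type first would weight edge types equally, which is a design choice a real model has to make.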
- abstract predict(blocks, target_edges, node_feats, edge_feats, target_edge_feats, input_nodes, return_proba)
Make prediction on the target edges.
Parameters
- blocks: list of DGL MFGs
Sampled subgraph in the list of DGL message flow graph (MFG) format. More detailed information about DGL MFG can be found in DGL Neighbor Sampling Overview.
- target_edges: a DGLGraph
The graph that stores the target edges on which to run edge prediction.
- node_feats: dict of Tensors
The node features of the message passing graph.
- edge_feats: dict of Tensors
The edge features of the message passing graph.
- target_edge_feats: dict of Tensors
The edge features of the target edges.
- input_nodes: dict of Tensors
The input nodes of the mini-batch.
- return_proba: bool
Whether to return the prediction results of all dimensions, or only the argmaxed ones, in classification models.
Returns
- Tensor, or dict of Tensors:
GNN prediction results. Returns the results of all dimensions when return_proba is True; otherwise returns the argmaxed results.
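The effect of return_proba on the returned results can be sketched as follows. This is not GraphStorm's implementation: it assumes a multi-class decoder that produces one row of logits per target edge, keyed by DGL canonical edge type, and it normalizes rows with softmax when all dimensions are requested (whether real results are probabilities or raw scores depends on the decoder). Plain lists stand in for Tensors, and `softmax` and `format_predictions` are hypothetical names.

```python
import math
from typing import Dict, List, Tuple, Union

# Assumption: results are keyed by DGL canonical edge types.
EdgeType = Tuple[str, str, str]


def softmax(logits: List[float]) -> List[float]:
    """Turn one row of logits into probabilities that sum to 1."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def format_predictions(logits: Dict[EdgeType, List[List[float]]],
                       return_proba: bool
                       ) -> Dict[EdgeType, Union[List[List[float]], List[int]]]:
    """Hypothetical helper mimicking the documented return_proba switch."""
    results: Dict[EdgeType, Union[List[List[float]], List[int]]] = {}
    for etype, rows in logits.items():
        if return_proba:
            # Keep every dimension: one score per class for each edge.
            results[etype] = [softmax(row) for row in rows]
        else:
            # Collapse each row to the index of its largest value (argmax).
            results[etype] = [max(range(len(row)), key=row.__getitem__)
                              for row in rows]
    return results


etype = ("user", "rates", "movie")
logits = {etype: [[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]]}
labels_only = format_predictions(logits, return_proba=False)  # {etype: [1, 0]}
all_dims = format_predictions(logits, return_proba=True)      # 3 scores per edge
```

Because softmax preserves the ordering of the logits, the argmaxed output is the index of the largest probability in the return_proba=True output.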