GSgnnNodeModelInterface
- class graphstorm.model.GSgnnNodeModelInterface
Bases: object
The interface for GraphStorm node prediction models.
This interface defines two main methods: forward() for training and predict() for inference. Node models should inherit this interface and implement both methods.
- abstract forward(blocks, node_feats, edge_feats, labels, input_nodes=None)
The forward function for node prediction.
This method is used for training. It takes a list of DGL message flow graphs (MFGs), node features, edge features, and node labels of a mini-batch as inputs, and returns the model's loss on that mini-batch.
Parameters
- blocks: list of DGL MFGs
Sampled subgraphs as a list of DGL message flow graphs (MFGs). More details about DGL MFGs can be found in the DGL Neighbor Sampling Overview.
- node_feats: dict of Tensors
The input node features of the message passing graph.
- edge_feats: dict of Tensors
The input edge features of the message passing graph.
- labels: dict of Tensor
The labels of the predicted nodes.
- input_nodes: dict of Tensors
The input nodes of the mini-batch.
Returns
float: The prediction loss of this mini-batch.
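As an illustration of the kind of value forward() returns, here is a minimal, framework-free sketch of a mini-batch loss for node classification. It assumes that labels and per-node class logits are dicts keyed by node type (mirroring the dict-of-Tensor arguments above) and uses plain Python lists in place of PyTorch tensors. The helper name mini_batch_loss and the choice of cross-entropy averaged per node type and summed across types are hypothetical, not GraphStorm's actual loss implementation.

```python
import math
from typing import Dict, List


def mini_batch_loss(logits: Dict[str, List[List[float]]],
                    labels: Dict[str, List[int]]) -> float:
    """Hypothetical helper: mean cross-entropy per node type, summed over node types."""
    total = 0.0
    for ntype, ntype_labels in labels.items():
        if not ntype_labels:
            continue  # no labeled nodes of this type in the mini-batch
        rows = logits[ntype]
        ntype_loss = 0.0
        for row, label in zip(rows, ntype_labels):
            # Stable log-softmax: -log p(label) = logsumexp(row) - row[label]
            m = max(row)
            log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
            ntype_loss += log_sum_exp - row[label]
        total += ntype_loss / len(ntype_labels)
    return total
```

For example, uniform logits over two classes for a single labeled node give a loss of log 2 for that node type, regardless of which class is the label.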
- abstract predict(blocks, node_feats, edge_feats, input_nodes, return_proba)
Make predictions on the input nodes.
This method is used for inference. It takes a list of DGL message flow graphs (MFGs), node features, edge features, and the input nodes of a mini-batch as inputs, and computes predictions for the input nodes.
Parameters
- blocks: list of DGL MFGs
Sampled subgraphs as a list of DGL message flow graphs (MFGs). More details about DGL MFGs can be found in the DGL Neighbor Sampling Overview.
- node_feats: dict of Tensors
The node features of the message passing graph.
- edge_feats: dict of Tensors
The edge features of the message passing graph.
- input_nodes: dict of Tensors
The input nodes of the mini-batch.
- return_proba: bool
Whether to return the full predicted results (all dimensions) or only the argmaxed ones in classification models.
Returns
- Tensor, or dict of Tensor:
Prediction results. Return results of all dimensions when return_proba is True, otherwise return the argmaxed results.
- Tensor, or dict of Tensor:
The GNN embeddings.
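The contract described in this section can be sketched with Python's abc module. NodeModelInterface and ToyNodeModel below are hypothetical stand-ins, not the GraphStorm classes: they use plain lists instead of DGL MFGs and tensors, assume dicts keyed by node type, ignore blocks and edge_feats, and treat each node's feature vector directly as its class logits. The sketch shows a subclass implementing both abstract methods, with predict() returning either all class probabilities (return_proba=True) or argmaxed labels, together with embeddings.

```python
import abc
import math
from typing import Dict, List


def _softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one node's class logits."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


class NodeModelInterface(abc.ABC):
    """Hypothetical stand-in mirroring the two abstract methods documented above."""

    @abc.abstractmethod
    def forward(self, blocks, node_feats, edge_feats, labels, input_nodes=None):
        """Training: return the loss of one mini-batch."""

    @abc.abstractmethod
    def predict(self, blocks, node_feats, edge_feats, input_nodes, return_proba):
        """Inference: return (predictions, embeddings) for the input nodes."""


class ToyNodeModel(NodeModelInterface):
    """Toy model: each node's feature vector is used directly as its class logits.

    blocks and edge_feats are ignored; a real model would run message passing
    over the MFGs to produce embeddings and logits.
    """

    def forward(self, blocks, node_feats, edge_feats, labels, input_nodes=None):
        # Mean cross-entropy over all labeled nodes of all node types.
        loss, count = 0.0, 0
        for ntype, ntype_labels in labels.items():
            for logits, label in zip(node_feats[ntype], ntype_labels):
                loss -= math.log(_softmax(logits)[label])
                count += 1
        return loss / max(count, 1)

    def predict(self, blocks, node_feats, edge_feats, input_nodes, return_proba):
        predictions: Dict[str, list] = {}
        for ntype in input_nodes:
            probs = [_softmax(logits) for logits in node_feats[ntype]]
            if return_proba:
                # Keep every class dimension.
                predictions[ntype] = probs
            else:
                # Keep only the index of the most probable class.
                predictions[ntype] = [max(range(len(p)), key=p.__getitem__) for p in probs]
        # In this toy, the "embeddings" are just the input features.
        embeddings = {ntype: node_feats[ntype] for ntype in input_nodes}
        return predictions, embeddings
```

Because both methods are declared abstract, a subclass that omits either of them cannot be instantiated; Python raises TypeError at construction time.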