GSgnnNodeModelInterface

class graphstorm.model.GSgnnNodeModelInterface

Bases: object

The interface for GraphStorm node prediction model.

This interface defines two main methods: forward() for training and predict() for inference. Node models should inherit from this interface and implement both methods.
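As a rough illustration of the shape of a subclass, the sketch below uses a hypothetical stand-in base class (NodeModelInterfaceSketch) rather than the real graphstorm.model.GSgnnNodeModelInterface, so that it runs without GraphStorm, DGL, or PyTorch installed. Plain Python lists stand in for Tensors, an empty list stands in for the DGL MFGs, and the "paper" node type and the dicts keyed by node type are assumptions made for the example.

```python
class NodeModelInterfaceSketch:
    """Hypothetical stand-in for the interface; not the GraphStorm class."""

    def forward(self, blocks, node_feats, edge_feats, labels, input_nodes=None):
        raise NotImplementedError

    def predict(self, blocks, node_feats, edge_feats, input_nodes, return_proba):
        raise NotImplementedError


class ConstantNodeModel(NodeModelInterfaceSketch):
    """Toy model that ignores the graph and always predicts class 0."""

    def forward(self, blocks, node_feats, edge_feats, labels, input_nodes=None):
        # Toy loss: fraction of labeled nodes whose label is not 0.
        all_labels = [y for ys in labels.values() for y in ys]
        return sum(1 for y in all_labels if y != 0) / len(all_labels)

    def predict(self, blocks, node_feats, edge_feats, input_nodes, return_proba):
        # One prediction per input node, keyed by node type (assumed layout).
        preds = {ntype: [0] * len(ids) for ntype, ids in input_nodes.items()}
        # No GNN is run here, so the raw features double as "embeddings".
        return preds, node_feats


model = ConstantNodeModel()
feats = {"paper": [[0.1], [0.2]]}  # hypothetical node type and features

# Training path: forward() returns the mini-batch loss.
loss = model.forward(blocks=[], node_feats=feats, edge_feats={},
                     labels={"paper": [0, 1]})

# Inference path: predict() returns predictions and embeddings.
preds, embs = model.predict(blocks=[], node_feats=feats, edge_feats={},
                            input_nodes={"paper": [7, 9]}, return_proba=False)
```

A real node model would subclass the GraphStorm interface and run its GNN layers over the blocks; the toy model only mirrors the two method signatures and return shapes described here.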

abstract forward(blocks, node_feats, edge_feats, labels, input_nodes=None)

The forward function for node prediction.

This method is used for training. It takes a list of DGL message flow graphs (MFGs), node features, edge features, and node labels of a mini-batch as inputs, and computes the loss of the model in the mini-batch as the return value.

Parameters

blocks: list of DGL MFGs

Sampled subgraphs as a list of DGL message flow graphs (MFGs). More detailed information about DGL MFGs can be found in DGL Neighbor Sampling Overview.

node_feats: dict of Tensors

The input node features of the message passing graph.

edge_feats: dict of Tensors

The input edge features of the message passing graph.

labels: dict of Tensor

The labels of the predicted nodes.

input_nodes: dict of Tensors

The input nodes of the mini-batch. Default: None.

Returns

float: The prediction loss of this mini-batch.
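To make the return value concrete, here is a minimal sketch of reducing per-node outputs to a single mini-batch loss. It assumes a classification task with mean cross-entropy, which this reference does not prescribe. The helper names (softmax, mini_batch_loss), the "paper" node type, the logits argument, and the use of plain lists in place of Tensors are all hypothetical.

```python
import math


def softmax(scores):
    """Numerically stable softmax over one node's class scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def mini_batch_loss(logits, labels):
    """Mean cross-entropy over every labeled node in the mini-batch.

    Both arguments are dicts keyed by node type (assumed layout):
    logits[ntype] holds one score list per node, labels[ntype] one class id.
    """
    losses = []
    for ntype, ys in labels.items():
        for scores, y in zip(logits[ntype], ys):
            losses.append(-math.log(softmax(scores)[y]))
    return sum(losses) / len(losses)


# Two labeled "paper" nodes with two classes each (made-up numbers).
logits = {"paper": [[2.0, 0.0], [0.0, 0.0]]}
labels = {"paper": [0, 1]}
loss = mini_batch_loss(logits, labels)
```

In an actual forward() implementation, the logits would come from running the model over blocks, node_feats, and edge_feats, and the loss would typically be computed with the framework's own loss functions.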

abstract predict(blocks, node_feats, edge_feats, input_nodes, return_proba)

Make prediction on the input nodes.

This method is used for inference. It takes a list of DGL message flow graphs (MFGs), node features, edge features, and input nodes of a mini-batch as inputs, and computes the predictions of the input nodes. More detailed information about DGL MFGs can be found in DGL Neighbor Sampling Overview.

Parameters

blocks: list of DGL MFGs

Sampled subgraphs as a list of DGL message flow graphs (MFGs). More detailed information about DGL MFGs can be found in DGL Neighbor Sampling Overview.

node_feats: dict of Tensors

The node features of the message passing graph.

edge_feats: dict of Tensors

The edge features of the message passing graph.

input_nodes: dict of Tensors

The input nodes of the mini-batch.

return_proba: bool

Whether to return the full prediction results across all dimensions, or only the argmaxed results, in classification models.

Returns

Tensor, or dict of Tensor:

Prediction results. Returns the results of all dimensions when return_proba is True; otherwise returns the argmaxed results.

Tensor, or dict of Tensor:

The GNN embeddings.
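The effect of return_proba on the first return value can be sketched as follows. This is an illustration only: format_predictions is a hypothetical helper, plain lists stand in for Tensors, and the dict keyed by a "paper" node type is an assumed layout.

```python
def format_predictions(scores, return_proba):
    """Return all dimensions, or only the argmax index per node.

    scores maps a node type to one list of per-class scores per node.
    """
    if return_proba:
        # Keep the results of all dimensions unchanged.
        return scores
    # Otherwise reduce each node's scores to the index of the largest one.
    return {
        ntype: [row.index(max(row)) for row in rows]
        for ntype, rows in scores.items()
    }


# Made-up scores for two "paper" nodes over three classes.
scores = {"paper": [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]]}

full = format_predictions(scores, return_proba=True)
classes = format_predictions(scores, return_proba=False)
```

With return_proba set to True the caller receives one score per class for each node; with False it receives a single class index per node.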