GSgnnEdgeModelInterface

class graphstorm.model.GSgnnEdgeModelInterface

Bases: object

The interface for GraphStorm edge prediction models.

This interface defines two main methods: forward() for training and predict() for inference. Edge GNN models should inherit this interface and implement both methods.
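The skeleton below is a minimal sketch of how a subclass fills in the two methods. To run without GraphStorm, DGL, or PyTorch installed, it uses a hypothetical stand-in base class, EdgeModelInterfaceStandIn, that only mirrors the documented signatures by raising NotImplementedError. MajorityEdgeModel is likewise hypothetical: a toy classifier that always predicts class 0, ignores the MFGs and features, and uses plain Python lists keyed by canonical edge type in place of tensors.

```python
class EdgeModelInterfaceStandIn:
    """Hypothetical stand-in for graphstorm.model.GSgnnEdgeModelInterface.

    It only mirrors the two documented method signatures.
    """

    def forward(self, blocks, target_edges, node_feats, edge_feats,
                target_edge_feats, labels, input_nodes=None):
        raise NotImplementedError

    def predict(self, blocks, target_edges, node_feats, edge_feats,
                target_edge_feats, input_nodes, return_proba):
        raise NotImplementedError


class MajorityEdgeModel(EdgeModelInterfaceStandIn):
    """Toy edge classifier (hypothetical): always predicts class 0 of 2 classes."""

    def forward(self, blocks, target_edges, node_feats, edge_feats,
                target_edge_feats, labels, input_nodes=None):
        # Training: return one scalar for the mini-batch. Here it is the error
        # rate of always predicting 0 (a toy value, not a differentiable loss).
        total = sum(len(lbls) for lbls in labels.values())
        wrong = sum(1 for lbls in labels.values() for y in lbls if y != 0)
        return wrong / total if total else 0.0

    def predict(self, blocks, target_edges, node_feats, edge_feats,
                target_edge_feats, input_nodes, return_proba):
        # Inference: one result per target edge, keyed by edge type.
        # Assumption: target_edge_feats holds one feature row per target edge.
        preds = {}
        for etype, feats in target_edge_feats.items():
            if return_proba:
                preds[etype] = [[1.0, 0.0] for _ in feats]  # all mass on class 0
            else:
                preds[etype] = [0 for _ in feats]           # argmax class index
        return preds


etype = ("user", "rates", "item")
model = MajorityEdgeModel()
print(model.forward(None, None, {}, {}, {etype: [[0.1]] * 4}, {etype: [0, 1, 0, 0]}))
# 0.25
print(model.predict(None, None, {}, {}, {etype: [[0.1], [0.2]]}, None, return_proba=False))
# {('user', 'rates', 'item'): [0, 0]}
```

A real model would run message passing over blocks, score the edges in target_edges, and return PyTorch tensors; the toy only shows which inputs each method receives and what shape of result it hands back.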

abstract forward(blocks, target_edges, node_feats, edge_feats, target_edge_feats, labels, input_nodes=None)

The forward function for edge prediction.

This method is used for training. It takes a list of DGL message flow graphs (MFGs), node features, edge features, and edge labels of a mini-batch as inputs, and returns the model's loss on that mini-batch. More detailed information about DGL MFGs can be found in the DGL Neighbor Sampling Overview.

Parameters

blocks: list of DGL MFGs

Sampled subgraphs as a list of DGL message flow graphs (MFGs). More detailed information about DGL MFGs can be found in the DGL Neighbor Sampling Overview.

target_edges: a DGLGraph

The graph that stores target edges to run edge prediction.

node_feats: dict of Tensors

The input node features of the message passing graph.

edge_feats: dict of Tensors

The input edge features of the message passing graph.

target_edge_feats: dict of Tensors

The edge features of the target edges.

labels: dict of Tensors

The labels of the target edges.

input_nodes: dict of Tensors

The input nodes of the mini-batch. Default: None.

Returns

float: The prediction loss of this mini-batch.
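The loss that forward() returns depends on the concrete model. As one assumed example, an edge classification model might average the cross-entropy over all target edges of all edge types. The hypothetical helper edge_cross_entropy below sketches that reduction on plain Python lists keyed by edge type, so it runs without PyTorch; a real implementation would operate on tensors (for instance with PyTorch's built-in cross-entropy) so the loss can be backpropagated.

```python
import math


def edge_cross_entropy(logits, labels):
    """Mean cross-entropy over all target edges of all edge types.

    logits: dict mapping edge type -> list of per-edge class-score lists.
    labels: dict mapping edge type -> list of integer class labels.
    """
    total_loss = 0.0
    num_edges = 0
    for etype, etype_logits in logits.items():
        for scores, label in zip(etype_logits, labels[etype]):
            # Numerically stable log-sum-exp: subtract the max score first.
            m = max(scores)
            log_norm = m + math.log(sum(math.exp(s - m) for s in scores))
            # -log softmax(scores)[label]
            total_loss += log_norm - scores[label]
            num_edges += 1
    return total_loss / num_edges if num_edges else 0.0


etype = ("user", "rates", "item")
print(edge_cross_entropy({etype: [[0.0, 0.0]]}, {etype: [1]}))  # log(2), about 0.693
```

Averaging over every target edge, rather than per edge type, weights each edge equally; a model could instead weight edge types differently, which is a design choice left to the implementation.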

abstract predict(blocks, target_edges, node_feats, edge_feats, target_edge_feats, input_nodes, return_proba)

Make prediction on the target edges.

Parameters

blocks: list of DGL MFGs

Sampled subgraphs as a list of DGL message flow graphs (MFGs). More detailed information about DGL MFGs can be found in the DGL Neighbor Sampling Overview.

target_edges: a DGLGraph

The graph that stores target edges to run edge prediction.

node_feats: dict of Tensors

The node features of the message passing graph.

edge_feats: dict of Tensors

The edge features of the message passing graph.

target_edge_feats: dict of Tensors

The edge features of the target edges.

input_nodes: dict of Tensors

The input nodes of the mini-batch.

return_proba: bool

In classification models, whether to return the prediction results of all dimensions, or only the argmax results.

Returns

Tensor, or dict of Tensor:

GNN prediction results. Returns the results of all dimensions when return_proba is True; otherwise, returns the argmax results.
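The effect of return_proba can be sketched with a hypothetical helper, finalize_predictions, that turns per-edge class scores into either full per-class results or argmax indices. Two assumptions here are not stated by the interface: the "all dimensions" output is shown as softmax probabilities, and results are plain Python lists keyed by edge type instead of tensors, so the sketch runs without PyTorch.

```python
import math


def finalize_predictions(logits, return_proba):
    """Turn per-edge class scores into prediction results.

    logits: dict mapping edge type -> list of per-edge class-score lists.
    """
    results = {}
    for etype, etype_logits in logits.items():
        if return_proba:
            # Keep all dimensions: softmax probabilities for every class.
            probs = []
            for scores in etype_logits:
                m = max(scores)
                exps = [math.exp(s - m) for s in scores]
                z = sum(exps)
                probs.append([e / z for e in exps])
            results[etype] = probs
        else:
            # Keep only the index of the highest-scoring class for each edge.
            results[etype] = [max(range(len(s)), key=s.__getitem__)
                              for s in etype_logits]
    return results


logits = {("user", "rates", "item"): [[0.2, 1.5, -0.3], [2.0, 0.1, 0.0]]}
print(finalize_predictions(logits, return_proba=False))
# {('user', 'rates', 'item'): [1, 0]}
```

With return_proba=True, each edge instead gets a list of three probabilities that sum to 1, and the largest entry sits at the same index as the argmax result.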