GSgnnEdgePredictionInferrer

class graphstorm.inference.GSgnnEdgePredictionInferrer(model)

Bases: GSInferrer

Inferrer for edge prediction tasks.

GSgnnEdgePredictionInferrer defines the infer() method, which performs three tasks:

  • Generate node embeddings and save to disk.

  • Compute inference results for edges of the target edge type.

  • (Optional) Evaluate model performance on a test set, if one is given.

Parameters

model: GSgnnEdgeModelBase

The GNN model for edge prediction. It can be either a subclass of GSgnnEdgeModelBase, or a class that inherits from both GSgnnModelBase and GSgnnEdgeModelInterface.
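The two accepted model shapes can be illustrated with a minimal sketch. The three base classes here are empty local stand-ins named after the GraphStorm classes (they are not imported from graphstorm, and their real inheritance relationships are not modeled), and `is_supported_edge_model` is a hypothetical helper, not part of the GraphStorm API.

```python
class GSgnnModelBase:
    """Local stand-in for the GraphStorm class of the same name."""


class GSgnnEdgeModelInterface:
    """Local stand-in for the GraphStorm class of the same name."""


class GSgnnEdgeModelBase:
    """Local stand-in; its real base classes are not modeled here."""


def is_supported_edge_model(model: object) -> bool:
    """Hypothetical check mirroring the documented requirement on `model`."""
    # Option 1: a subclass of GSgnnEdgeModelBase.
    if isinstance(model, GSgnnEdgeModelBase):
        return True
    # Option 2: a class that inherits from both GSgnnModelBase and
    # GSgnnEdgeModelInterface.
    return isinstance(model, GSgnnModelBase) and isinstance(
        model, GSgnnEdgeModelInterface
    )


class EdgeModelA(GSgnnEdgeModelBase):
    """Accepted via option 1."""


class EdgeModelB(GSgnnModelBase, GSgnnEdgeModelInterface):
    """Accepted via option 2."""


class NodeOnlyModel(GSgnnModelBase):
    """Rejected: lacks the edge model interface."""
```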

infer(loader, save_embed_path, save_prediction_path=None, use_mini_batch_infer=False, node_id_mapping_file=None, edge_id_mapping_file=None, return_proba=True, save_embed_format='pytorch')

Run inference.

Parameters

loader: GSEdgeDataLoader

The edge dataloader for the edge prediction task.

save_embed_path: str

The path where the GNN embeddings will be saved.

save_prediction_path: str

The path where the prediction results will be saved. If None, predictions are not saved. Default: None.

use_mini_batch_infer: bool

Whether to use mini-batch for inference. Default: False.
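As a rough intuition for mini-batch inference, the following toy sketch compares a single full-graph pass with batch-by-batch computation. The one-layer mean-aggregation model on a dense adjacency matrix is an assumption for illustration, not GraphStorm's inference code; it only shows that, when each batch uses the full neighborhood of its target nodes, producing outputs in smaller pieces gives the same result as one full-graph pass.

```python
import torch


def full_graph_infer(adj_norm, feats, weight):
    # One mean-aggregation layer applied to every node at once.
    return adj_norm @ feats @ weight


def mini_batch_infer(adj_norm, feats, weight, batch_size):
    # The same layer applied to slices of target nodes (rows of adj_norm).
    outs = []
    for start in range(0, adj_norm.shape[0], batch_size):
        rows = adj_norm[start:start + batch_size]
        outs.append(rows @ feats @ weight)
    return torch.cat(outs, dim=0)


torch.manual_seed(0)
num_nodes, in_dim, out_dim = 6, 4, 3
adj = (torch.rand(num_nodes, num_nodes) > 0.5).float()
adj.fill_diagonal_(1.0)  # self-loops so no row sums to zero
adj_norm = adj / adj.sum(dim=1, keepdim=True)  # row-normalized mean aggregation
feats = torch.randn(num_nodes, in_dim)
weight = torch.randn(in_dim, out_dim)

full = full_graph_infer(adj_norm, feats, weight)
batched = mini_batch_infer(adj_norm, feats, weight, batch_size=4)
```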

node_id_mapping_file: str

Path to the file storing the node ID mapping generated by the graph partition algorithm. If None, no node ID mapping is performed. Default: None.
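If the partition algorithm renumbers nodes, embeddings computed in partition order need to be put back into the original node order. A minimal sketch, assuming the mapping is a tensor where `mapping[i]` holds the original ID of the node whose partition-assigned ID is `i`; the actual on-disk format of `node_id_mapping_file` is not described in this section, and `to_original_order` is a hypothetical helper.

```python
import torch


def to_original_order(emb: torch.Tensor, shuffled_to_orig: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: row k of the result is the embedding of original node k.

    Assumes shuffled_to_orig[i] is the original ID of partition-assigned node i.
    """
    out = torch.empty_like(emb)
    # Scatter each row to the position given by its original ID.
    out[shuffled_to_orig] = emb
    return out


# Rows are in partition order: node 0 here was originally node 2, and so on.
emb = torch.tensor([[10.0], [20.0], [30.0]])
mapping = torch.tensor([2, 0, 1])
restored = to_original_order(emb, mapping)
```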

return_proba: bool

Whether to return the predicted probabilities, or only the argmax predictions, in classification models. Default: True.
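The difference controlled by `return_proba` can be sketched in PyTorch. The softmax normalization and the `format_predictions` helper are assumptions for illustration; the decoder of a given model may produce its scores differently.

```python
import torch


def format_predictions(logits: torch.Tensor, return_proba: bool = True) -> torch.Tensor:
    """Hypothetical helper: per-class probabilities vs. argmax class indices."""
    if return_proba:
        # One probability per class; each row sums to 1.
        return torch.softmax(logits, dim=-1)
    # Only the index of the highest-scoring class for each edge.
    return logits.argmax(dim=-1)


# Scores for two edges over three classes.
logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.3, 3.0]])
proba = format_predictions(logits, return_proba=True)
labels = format_predictions(logits, return_proba=False)
```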

save_embed_format: str

The data format of the saved embeddings. Currently only PyTorch tensors are supported. Default: “pytorch”.
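Because embeddings in the “pytorch” format are PyTorch tensors, a round trip through torch.save and torch.load illustrates how such a file can be read back. This is a minimal sketch: the file name `emb.pt` and the `save_and_reload` helper are hypothetical, and GraphStorm's actual output directory layout and file naming are not specified in this section.

```python
import os
import tempfile

import torch


def save_and_reload(emb: torch.Tensor) -> torch.Tensor:
    """Save a tensor with torch.save and load it back with torch.load."""
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "emb.pt")  # hypothetical file name
        torch.save(emb, path)
        # Load before the temporary directory is cleaned up.
        return torch.load(path)


emb = torch.randn(5, 8)
reloaded = save_and_reload(emb)
```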