GSgnnLinkPredictionInferrer

class graphstorm.inference.GSgnnLinkPredictionInferrer(model)

Bases: GSInferrer

Inferrer for link prediction tasks.

GSgnnLinkPredictionInferrer defines the infer() method, which performs two tasks:

  • Generate node embeddings and save to disk.

  • (Optional) Evaluate the model performance on a test set if given.
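A minimal sketch of that two-step flow in plain Python, for illustration only. The function infer_sketch and its three callable arguments are hypothetical placeholders, not GraphStorm's API; the real infer() takes the arguments documented below.

```python
from typing import Callable, Dict, List, Optional

Embeddings = Dict[str, List[List[float]]]  # node type -> one vector per node


# Hypothetical sketch of the two tasks described above; not GraphStorm's code.
def infer_sketch(
    compute_embeddings: Callable[[], Embeddings],
    save_embeddings: Callable[[Embeddings], None],
    evaluate: Optional[Callable[[Embeddings], float]] = None,
) -> Optional[float]:
    # Task 1: compute node embeddings and persist them.
    embeddings = compute_embeddings()
    save_embeddings(embeddings)
    # Task 2 (optional): evaluate only when an evaluator is supplied.
    if evaluate is None:
        return None
    return evaluate(embeddings)
```

Without an evaluate callable, the sketch only computes and saves the embeddings and returns None.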

Parameters

model: GSgnnLinkPredictionModelBase

The GNN model for link prediction. It can be an instance of a model class that inherits from GSgnnLinkPredictionModelBase, or of a model class that inherits from both GSgnnModelBase and GSgnnLinkPredictionModelInterface.
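The second option is ordinary Python multiple inheritance. The sketch below illustrates the pattern with hypothetical stand-in classes: ModelBaseStandIn and LinkPredictionInterfaceStandIn play the roles of GSgnnModelBase and GSgnnLinkPredictionModelInterface, and the method names load_weights and score_edges are illustrative only, not the abstract methods GraphStorm actually defines.

```python
import abc


# Hypothetical stand-ins; the real GraphStorm base classes define
# different abstract methods, which are not reproduced here.
class ModelBaseStandIn(abc.ABC):
    @abc.abstractmethod
    def load_weights(self, path: str) -> None:
        """Generic model plumbing (checkpointing and similar)."""


class LinkPredictionInterfaceStandIn(abc.ABC):
    @abc.abstractmethod
    def score_edges(self, src_emb, dst_emb):
        """Task-specific link prediction logic."""


class MyLinkPredictionModel(ModelBaseStandIn, LinkPredictionInterfaceStandIn):
    """Satisfies both contracts by implementing every abstract method."""

    def load_weights(self, path: str) -> None:
        self.path = path

    def score_edges(self, src_emb, dst_emb):
        # Dot-product score between two embedding vectors.
        return sum(s * d for s, d in zip(src_emb, dst_emb))


def is_accepted_model(model) -> bool:
    # The model must satisfy both the base-class and the interface contract.
    return isinstance(model, ModelBaseStandIn) and isinstance(
        model, LinkPredictionInterfaceStandIn
    )
```

Because both parents are abstract, Python refuses to instantiate a subclass that leaves any of their abstract methods unimplemented.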

infer(data, loader, save_embed_path, edge_mask_for_gnn_embeddings='train_mask', use_mini_batch_infer=False, node_id_mapping_file=None, save_embed_format='pytorch', infer_batch_size=1024)

Perform inference.

Parameters

data: GSgnnData

The GraphStorm dataset.

loader: GSgnnLinkPredictionTestDataLoader

The link prediction test dataloader, which provides the edges used for evaluation.

save_embed_path: str

The path where the GNN embeddings will be saved.

edge_mask_for_gnn_embeddings: str

The mask that indicates the edges used for computing GNN embeddings for model evaluation. By default, it uses the edges in the training graph to compute GNN embeddings for evaluation. Default: “train_mask”.

use_mini_batch_infer: bool

Whether to compute node embeddings with mini-batch inference instead of full-graph inference. Default: False.

node_id_mapping_file: str

Path to the file storing the node ID mapping generated by the graph partition algorithm. If None, no node ID mapping is performed. Default: None.

save_embed_format: str

The data format of the saved embeddings. Currently, only PyTorch tensors are supported. Default: “pytorch”.

infer_batch_size: int

The inference batch size used when computing node embeddings with mini-batch inference. Default: 1024.
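The edge_mask_for_gnn_embeddings parameter restricts which edges take part in computing the embeddings used for evaluation. The sketch below shows the idea, assuming the mask is a per-edge boolean flag; the helper edges_for_embeddings is hypothetical and operates on plain Python lists rather than GraphStorm's graph objects.

```python
from typing import List, Tuple

Edge = Tuple[int, int]  # (source node ID, destination node ID)


# Hypothetical helper: keep only the edges whose mask entry is True.
def edges_for_embeddings(edges: List[Edge], mask: List[bool]) -> List[Edge]:
    if len(edges) != len(mask):
        raise ValueError("mask must have one entry per edge")
    # Edges with a False mask entry are hidden from message passing.
    return [edge for edge, keep in zip(edges, mask) if keep]
```

With edges [(0, 1), (1, 2), (2, 0)] and a training mask of [True, True, False], only (0, 1) and (1, 2) contribute to the embeddings, so an edge held out for testing does not leak into the representations it is scored with.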
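The node_id_mapping_file parameter exists because graph partitioning can renumber nodes, so embeddings indexed by the new IDs must be reordered before they line up with the original IDs. The sketch below assumes the mapping is a sequence whose position i holds the original ID of the node renumbered to i; that layout and the helper remap_to_original_order are assumptions for illustration, not the documented file format.

```python
from typing import List, Sequence


# Hypothetical helper. embeddings[i] belongs to the node whose
# partition-assigned ID is i; shuffled_to_orig[i] is that node's
# original ID (assumed layout).
def remap_to_original_order(
    embeddings: Sequence[List[float]], shuffled_to_orig: Sequence[int]
) -> List[List[float]]:
    n = len(embeddings)
    if len(shuffled_to_orig) != n:
        raise ValueError("need one original ID per embedding row")
    if sorted(shuffled_to_orig) != list(range(n)):
        raise ValueError("mapping must be a permutation of 0..n-1")
    out: List[List[float]] = [[] for _ in range(n)]
    for new_id, orig_id in enumerate(shuffled_to_orig):
        # Row new_id of the input becomes row orig_id of the output.
        out[orig_id] = list(embeddings[new_id])
    return out
```

For example, embeddings [[0.0], [1.0], [2.0]] with mapping [2, 0, 1] come back as [[1.0], [2.0], [0.0]].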
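When use_mini_batch_infer is True, infer_batch_size bounds how many nodes are processed at once. The sketch below shows only the batching step: iter_batches and minibatch_embed are hypothetical helpers, and embed_batch stands in for whatever computes embeddings for one batch. It deliberately omits the neighborhood handling a real GNN needs.

```python
from typing import Callable, Iterator, List, Sequence


def iter_batches(node_ids: Sequence[int], batch_size: int) -> Iterator[Sequence[int]]:
    # Yield consecutive slices of at most batch_size node IDs.
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(node_ids), batch_size):
        yield node_ids[start:start + batch_size]


def minibatch_embed(
    node_ids: Sequence[int],
    embed_batch: Callable[[Sequence[int]], List[List[float]]],
    infer_batch_size: int = 1024,
) -> List[List[float]]:
    # Embed one batch at a time and concatenate results in input order.
    out: List[List[float]] = []
    for batch in iter_batches(node_ids, infer_batch_size):
        out.extend(embed_batch(batch))
    return out
```

With 2,500 nodes and the default infer_batch_size of 1024, the loop runs three batches of 1024, 1024, and 452 nodes. Smaller batches lower peak memory at the cost of more iterations.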