GSgnnNodePredictionInferrer

class graphstorm.inference.GSgnnNodePredictionInferrer(model)

Bases: GSInferrer

Inferrer for node prediction tasks.

GSgnnNodePredictionInferrer defines the infer() method, which performs three tasks:

  • Generate node embeddings and save to disk.

  • Compute inference results for nodes of the target node type.

  • (Optional) Evaluate the model performance on a test set if given.
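The three tasks can be outlined in plain Python as below. This is a hypothetical sketch, not GraphStorm's implementation: `node_prediction_infer_sketch`, `embed_fn`, and `predict_fn` are illustrative names, the embeddings are kept in memory instead of being saved to disk, and accuracy is assumed as the evaluation metric only for demonstration.

```python
from typing import Callable, Dict, List, Optional


def node_prediction_infer_sketch(
    node_ids: List[int],
    embed_fn: Callable[[int], List[float]],
    predict_fn: Callable[[List[float]], int],
    test_labels: Optional[Dict[int, int]] = None,
):
    """Hypothetical outline of the three inference tasks."""
    # Task 1: generate node embeddings (the real inferrer also saves them to disk).
    embeddings = {nid: embed_fn(nid) for nid in node_ids}
    # Task 2: compute predictions for the target-type nodes from their embeddings.
    predictions = {nid: predict_fn(emb) for nid, emb in embeddings.items()}
    # Task 3 (optional): evaluate on a test set only when labels are given.
    accuracy = None
    if test_labels:
        correct = sum(predictions[nid] == label for nid, label in test_labels.items())
        accuracy = correct / len(test_labels)
    return embeddings, predictions, accuracy


# Toy usage: embeddings and the classifier are stand-ins for a trained model.
emb, pred, acc = node_prediction_infer_sketch(
    [0, 1, 2],
    embed_fn=lambda nid: [float(nid), 1.0],
    predict_fn=lambda e: int(e[0] >= 1.0),
    test_labels={0: 0, 1: 1, 2: 0},
)
```

Evaluation is skipped when no test labels are passed, mirroring the "(Optional)" wording of the third task.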

Parameters

model: GSgnnNodeModelBase

The GNN model for node prediction. It can be an instance of a class that inherits from GSgnnNodeModelBase, or of a class that inherits from both GSgnnModelBase and GSgnnNodeModelInterface.
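The two accepted model shapes can be expressed as an `isinstance` check. In this sketch the three base classes are empty hypothetical stand-ins named after the GraphStorm classes; their real definitions and their actual relationship to each other are not modeled, and `is_valid_node_model` is an illustrative helper, not a GraphStorm function.

```python
class GSgnnModelBase:
    """Hypothetical stand-in for GraphStorm's GSgnnModelBase."""


class GSgnnNodeModelInterface:
    """Hypothetical stand-in for GraphStorm's GSgnnNodeModelInterface."""


class GSgnnNodeModelBase:
    """Hypothetical stand-in for GraphStorm's GSgnnNodeModelBase."""


def is_valid_node_model(model) -> bool:
    # Option 1: the model derives from GSgnnNodeModelBase.
    if isinstance(model, GSgnnNodeModelBase):
        return True
    # Option 2: the model derives from both GSgnnModelBase and GSgnnNodeModelInterface.
    return isinstance(model, GSgnnModelBase) and isinstance(model, GSgnnNodeModelInterface)


class MyNodeModelA(GSgnnNodeModelBase):
    pass


class MyNodeModelB(GSgnnModelBase, GSgnnNodeModelInterface):
    pass


class NotANodeModel(GSgnnModelBase):
    pass
```

A model that only inherits GSgnnModelBase, like `NotANodeModel`, satisfies neither option.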

infer(loader, save_embed_path, save_prediction_path=None, use_mini_batch_infer=False, node_id_mapping_file=None, return_proba=True, save_embed_format='pytorch')

Perform inference.

Parameters

loader: GSNodeDataLoader

Node dataloader for the node prediction task.

save_embed_path: str

The path where the GNN embeddings will be saved.

save_prediction_path: str

The path where the prediction results will be saved. If None, the predictions are not saved. Default: None.

use_mini_batch_infer: bool

Whether to use mini-batch inference. Default: False.

node_id_mapping_file: str

Path to the file storing the node ID mapping generated by the graph partition algorithm. If None, no node ID mapping is performed. Default: None.

return_proba: bool

For classification models, whether to return the full predicted results (e.g., class probabilities) or only the argmax class of each node. Default: True.

save_embed_format: str

The data format of the saved embeddings. Currently, only PyTorch tensors are supported. Default: "pytorch".
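Three of the options above (use_mini_batch_infer, return_proba, and node_id_mapping_file) affect how prediction results are produced. The plain-Python sketch below illustrates their effect on a toy problem. It is not GraphStorm's implementation: `predict_sketch`, `toy_logits`, and the use of softmax are hypothetical, and the node ID mapping is assumed, for illustration only, to be a list in which position i holds the original ID of the node with partition-assigned ID i. The real format of node_id_mapping_file may differ.

```python
import math
from typing import Callable, Dict, Iterator, List, Optional, Sequence


def iter_minibatches(node_ids: Sequence[int], batch_size: int) -> Iterator[Sequence[int]]:
    """Yield consecutive slices of node_ids with at most batch_size items."""
    for start in range(0, len(node_ids), batch_size):
        yield node_ids[start:start + batch_size]


def softmax(logits: Sequence[float]) -> List[float]:
    """Convert one row of logits to probabilities (max subtracted for stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_sketch(
    node_ids: Sequence[int],
    logits_fn: Callable[[Sequence[int]], List[List[float]]],
    batch_size: Optional[int] = None,
    return_proba: bool = True,
    orig_ids: Optional[Sequence[int]] = None,
) -> Dict[int, object]:
    # batch_size=None processes every node in one call; otherwise use mini-batches.
    batches = [node_ids] if batch_size is None else iter_minibatches(node_ids, batch_size)
    logits = [row for batch in batches for row in logits_fn(batch)]
    probs = [softmax(row) for row in logits]
    # return_proba=False keeps only the index of the highest-probability class.
    results = probs if return_proba else [max(range(len(p)), key=p.__getitem__) for p in probs]
    # Without a mapping, results stay keyed by the partition-assigned node IDs.
    if orig_ids is None:
        return dict(zip(node_ids, results))
    # Assumed mapping convention: orig_ids[i] is the original ID of partition ID i.
    return {orig_ids[nid]: res for nid, res in zip(node_ids, results)}


# Toy per-node logits for a 2-class problem; stands in for a trained model.
TABLE = {0: [2.0, 0.0], 1: [0.0, 1.0], 2: [0.5, 0.4]}


def toy_logits(batch: Sequence[int]) -> List[List[float]]:
    return [TABLE[nid] for nid in batch]


labels = predict_sketch([0, 1, 2], toy_logits, batch_size=2,
                        return_proba=False, orig_ids=[7, 5, 6])
full = predict_sketch([0, 1, 2], toy_logits)
mini = predict_sketch([0, 1, 2], toy_logits, batch_size=2)
```

Here `labels` maps original IDs to argmax classes, while `full` and `mini` hold identical per-class probabilities, since batching only changes how many nodes are processed per call, not the per-node computation.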