GSgnnLinkPredictionTrainer

class graphstorm.trainer.GSgnnLinkPredictionTrainer(model, topk_model_to_save=1)

Bases: GSgnnTrainer

Trainer for link prediction tasks.

GSgnnLinkPredictionTrainer is a high-level trainer wrapper that can be used directly to train a link prediction model. It defines two main functions:

  • fit(): trains the model that was provided to this trainer when the object was initialized.

  • eval(): evaluates the provided model on the validation and test datasets.

Example

from graphstorm.dataloading import GSgnnLinkPredictionDataLoader
from graphstorm.dataset import GSgnnData
from graphstorm.model import GSgnnLinkPredictionModel
from graphstorm.trainer import GSgnnLinkPredictionTrainer

lp_data = GSgnnData("...")
target_idx = lp_data.get_edge_train_set([("src_ntype1", "etype1", "dst_ntype1")])
train_loader = GSgnnLinkPredictionDataLoader(
    lp_data, target_idx, fanout=[10], batch_size=1024,
    num_negative_edges=10, node_feats="feat", train_task=True)
model = GSgnnLinkPredictionModel(alpha_l2norm=0.0)

trainer = GSgnnLinkPredictionTrainer(model)

trainer.fit(train_loader, num_epochs=2)

Parameters

model: GSgnnLinkPredictionModelBase

The GNN model for link prediction. It can be an instance of a class that inherits from GSgnnLinkPredictionModelBase, or of a class that inherits from both GSgnnModelBase and GSgnnLinkPredictionModelInterface.

topk_model_to_save: int

The number of best-performing models (top K) to keep, ranked by evaluation results. Default: 1.
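One way to picture how topk_model_to_save behaves is a min-heap that holds the K best (score, checkpoint) pairs and evicts the worst one when a better checkpoint arrives. The sketch below is a hypothetical helper, not GraphStorm's implementation. It assumes a higher validation score is better, and the class name TopKCheckpointTracker and the checkpoint names are illustrative only.

```python
import heapq


class TopKCheckpointTracker:
    """Hypothetical helper: keeps the K best checkpoints by validation score.

    Assumes a higher score is better.
    """

    def __init__(self, k=1):
        self.k = k
        self._heap = []  # min-heap of (score, checkpoint_name); heap[0] is the worst kept

    def update(self, score, checkpoint_name):
        """Record a new checkpoint and return the name of the one to delete, or None."""
        if len(self._heap) < self.k:
            heapq.heappush(self._heap, (score, checkpoint_name))
            return None
        if score > self._heap[0][0]:
            # The new checkpoint beats the worst kept one: swap them and report the evicted one.
            _, evicted = heapq.heapreplace(self._heap, (score, checkpoint_name))
            return evicted
        # The new checkpoint is not in the top K, so it is the one to discard.
        return checkpoint_name

    def kept(self):
        """Names of the retained checkpoints, best first."""
        return [name for _, name in sorted(self._heap, reverse=True)]
```

For example, with k=2 and per-epoch scores 0.30, 0.45, 0.40, 0.50, the tracker ends up keeping the checkpoints from epochs 3 and 1.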

fit(train_loader, num_epochs, val_loader=None, test_loader=None, use_mini_batch_infer=True, save_model_path=None, save_model_frequency=-1, save_perf_results_path=None, edge_mask_for_gnn_embeddings='train_mask', freeze_input_layer_epochs=0, max_grad_norm=None, grad_norm_type=2.0)

Fit function for link prediction.

This function trains the given link prediction model. It iterates over the training batches provided by train_loader, computes the loss for each batch, and then performs the backward step using the trainer’s own optimizer.

If an evaluator and a validation dataloader are provided to this trainer, the trainer evaluates the model during training in three cases:

  • At the end of each epoch.

  • At the evaluation frequency (number of iterations) defined in the evaluator.

  • Before saving a model checkpoint.
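The three evaluation triggers can be sketched as a small schedule function. This is a hypothetical illustration, not GraphStorm code. It assumes 1-based global iteration counting. The name evaluation_steps and its arguments eval_frequency and checkpoint_steps are stand-ins for the evaluator's frequency and the trainer's checkpoint schedule.

```python
def evaluation_steps(num_epochs, iters_per_epoch, eval_frequency, checkpoint_steps=()):
    """Hypothetical sketch: global iterations at which the trainer would evaluate.

    Evaluation fires (a) every `eval_frequency` iterations, (b) at the last
    iteration of each epoch, and (c) at every step in `checkpoint_steps`
    (evaluation before saving). Steps hit by several triggers appear once.
    """
    checkpoint_steps = set(checkpoint_steps)
    steps = set()
    for step in range(1, num_epochs * iters_per_epoch + 1):
        if step % eval_frequency == 0 or step in checkpoint_steps:
            steps.add(step)
        if step % iters_per_epoch == 0:  # last iteration of an epoch
            steps.add(step)
    return sorted(steps)
```

For example, with 2 epochs of 5 iterations and an evaluation frequency of 4, this sketch evaluates at iterations 4, 5, 8, and 10.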

Parameters

train_loader: GSgnnLinkPredictionDataLoader

Link prediction dataloader for mini-batch sampling the training set.

num_epochs: int

The maximum number of epochs used to train the model.

val_loader: GSgnnLinkPredictionDataLoader

Link prediction dataloader for mini-batch sampling the validation set. Default: None.

test_loader: GSgnnLinkPredictionDataLoader

Link prediction dataloader for mini-batch sampling the test set. Default: None.

use_mini_batch_infer: bool

Whether to use mini-batch for inference. Default: True.

save_model_path: str

The path where trained model checkpoints are saved. If None, model checkpoints are not saved. Default: None.

save_model_frequency: int

The number of training iterations between model checkpoints. Default: -1, meaning the model is saved only at the end of each epoch.
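The effect of save_model_frequency can be sketched as a function that lists the iterations at which a checkpoint would be written. This is a hypothetical illustration, and the function name checkpoint_steps is made up. It assumes that a checkpoint is always saved at the end of each epoch, and that a positive frequency adds a save every save_model_frequency iterations on top of that.

```python
def checkpoint_steps(num_epochs, iters_per_epoch, save_model_frequency=-1):
    """Hypothetical sketch: 1-based global iterations at which a checkpoint is saved.

    A checkpoint is written at the last iteration of every epoch; when
    save_model_frequency > 0, one is also written every save_model_frequency
    iterations. A value of -1 (or any value <= 0) disables the periodic saves.
    """
    steps = set()
    for step in range(1, num_epochs * iters_per_epoch + 1):
        if save_model_frequency > 0 and step % save_model_frequency == 0:
            steps.add(step)
        if step % iters_per_epoch == 0:  # end of an epoch
            steps.add(step)
    return sorted(steps)
```

With the default of -1 and 3 epochs of 4 iterations, this sketch saves at iterations 4, 8, and 12. Setting the frequency to 3 over 2 epochs adds saves at iterations 3 and 6.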

save_perf_results_path: str

The path of the file where the performance results are saved. Default: None.

edge_mask_for_gnn_embeddings: str

The mask that indicates the edges used for computing GNN embeddings for model evaluation. By default, it uses the edges in the training graph to compute GNN embeddings for evaluation. Default: “train_mask”.
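Conceptually, the mask is a per-edge boolean array. Only the edges whose entry is true take part in message passing when the evaluation embeddings are computed. Restricting this set to training edges keeps validation and test edges from leaking into the embeddings used to score them. The sketch below uses plain Python lists. The function name edges_for_embeddings and the dict-of-masks layout are assumptions for illustration, not GraphStorm's data structures.

```python
def edges_for_embeddings(edges, masks, mask_name="train_mask"):
    """Hypothetical sketch: keep only the edges whose mask entry is true.

    edges: list of (src, dst) pairs for one edge type.
    masks: dict mapping a mask name (e.g. "train_mask") to a list of booleans
           aligned position-by-position with `edges`.
    """
    mask = masks[mask_name]
    if len(mask) != len(edges):
        raise ValueError("mask length must match the number of edges")
    return [edge for edge, keep in zip(edges, mask) if keep]
```

For example, with edges [(0, 1), (1, 2), (2, 3), (3, 0)] and a train_mask of [True, True, False, True], only (2, 3) is excluded from the message-passing graph.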

freeze_input_layer_epochs: int

The number of initial epochs during which the trainable parameters of the input layer are frozen (not updated). This is commonly used when the input layer contains language models. Default: 0.
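The freezing schedule can be sketched with toy parameter objects that carry a requires_grad flag, following the PyTorch convention. This is a hypothetical illustration. It assumes that epochs are counted from 0 and that input-layer parameters can be identified by a name prefix (input_layer.). Neither assumption is taken from GraphStorm's code, and the names ToyParam and apply_freeze_schedule are made up.

```python
from dataclasses import dataclass


@dataclass
class ToyParam:
    """Stand-in for a model parameter; only its name and trainable flag matter here."""
    name: str
    requires_grad: bool = True


def apply_freeze_schedule(params, epoch, freeze_input_layer_epochs=0, prefix="input_layer."):
    """Hypothetical sketch: freeze input-layer params during the first N (0-based) epochs.

    Parameters whose name starts with `prefix` are treated as the input layer;
    all other parameters are left untouched. Returns True if the input layer
    is frozen for this epoch.
    """
    frozen = epoch < freeze_input_layer_epochs
    for p in params:
        if p.name.startswith(prefix):
            p.requires_grad = not frozen
    return frozen
```

With freeze_input_layer_epochs=2, the input layer is frozen in epochs 0 and 1 and becomes trainable again from epoch 2. The GNN layers train throughout.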

max_grad_norm: float

The maximum gradient norm used for gradient clipping, which can improve training stability. See torch.nn.utils.clip_grad_norm_ for details on this argument. Default: None.

grad_norm_type: float

The type of norm used for gradient clipping (for example, 2.0 for the L2 norm). See torch.nn.utils.clip_grad_norm_ for details on this argument. Default: 2.0.
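The following pure-Python sketch shows the computation that max_grad_norm and grad_norm_type control. It follows the documented behavior of torch.nn.utils.clip_grad_norm_:

  • All gradients are treated as one concatenated vector.
  • The norm of that vector, of type grad_norm_type, is computed.
  • Every gradient is scaled by max_grad_norm / total_norm when that ratio is below 1.

It works on plain lists for illustration only and is not a replacement for the PyTorch function. The function name clip_grad_norm is hypothetical.

```python
import math


def clip_grad_norm(grads, max_norm, norm_type=2.0):
    """Pure-Python sketch of total-norm gradient clipping.

    grads: list of lists of floats, one inner list per parameter.
    Returns (clipped_grads, total_norm), where total_norm is the norm
    before clipping.
    """
    flat = [g for grad in grads for g in grad]
    if norm_type == math.inf:
        total_norm = max((abs(g) for g in flat), default=0.0)
    else:
        total_norm = sum(abs(g) ** norm_type for g in flat) ** (1.0 / norm_type)
    # A small epsilon avoids division by zero, as in the PyTorch implementation.
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        grads = [[g * clip_coef for g in grad] for grad in grads]
    return grads, total_norm
```

For gradients [[3.0, 0.0], [4.0]], the L2 norm is 5.0. With max_norm=1.0, every entry is scaled by about 0.2, giving roughly [[0.6, 0.0], [0.8]]. Gradients whose norm is already below max_norm are returned unchanged.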

eval(model, data, val_loader, test_loader, total_steps, edge_mask_for_gnn_embeddings, use_mini_batch_infer=False)

Evaluate the model on the validation set and, if provided, on the test set.

Parameters

model: GSgnnLinkPredictionModelBase

The GNN model for link prediction. It can be an instance of a class that inherits from GSgnnLinkPredictionModelBase, or of a class that inherits from both GSgnnModelBase and GSgnnLinkPredictionModelInterface.

data: GSgnnData

The GSgnnData object associated with the dataloaders.

val_loader: GSgnnLinkPredictionDataLoader

Link prediction dataloader for mini-batch sampling the validation set. Can be None.

test_loader: GSgnnLinkPredictionDataLoader

Link prediction dataloader for mini-batch sampling the test set. Can be None.

total_steps: int

The total number of iterations.

edge_mask_for_gnn_embeddings: str

The mask that indicates the edges used for computing GNN embeddings for model evaluation.

use_mini_batch_infer: bool

Whether to use mini-batch for inference. Default: False.

Returns

val_score: dict

Validation scores for each evaluation metric, in the format {metric: val_score}.