GSgnnLinkPredictionTrainer
- class graphstorm.trainer.GSgnnLinkPredictionTrainer(model, topk_model_to_save=1)
Bases: GSgnnTrainer
Trainer for link prediction tasks.
GSgnnLinkPredictionTrainer is a high-level trainer wrapper that can be used directly to train a link prediction model. It defines two main functions:
- fit(): trains the model provided to this trainer when the object is initialized.
- eval(): evaluates the provided model against the validation and test datasets.
Example
    from graphstorm.dataloading import GSgnnLinkPredictionDataLoader
    from graphstorm.dataset import GSgnnData
    from graphstorm.model import GSgnnLinkPredictionModel
    from graphstorm.trainer import GSgnnLinkPredictionTrainer

    # Load the graph data and select the training edges of the target edge type.
    lp_data = GSgnnData("...")
    target_idx = lp_data.get_edge_train_set([("src_ntype1", "etype1", "dst_ntype1")])

    # Build a mini-batch dataloader over the training edges.
    train_loader = GSgnnLinkPredictionDataLoader(
        lp_data, target_idx, fanout=[10], batch_size=1024,
        num_negative_edges=10, node_feats="feat", train_task=True)

    # Create the model, wrap it in a trainer, and train for two epochs.
    model = GSgnnLinkPredictionModel(alpha_l2norm=0.0)
    trainer = GSgnnLinkPredictionTrainer(model)
    trainer.fit(train_loader, num_epochs=2)
Parameters
- model: GSgnnLinkPredictionModelBase
The GNN model for link prediction, which could be a model class inherited from GSgnnLinkPredictionModelBase, or a model class that inherits from both GSgnnModelBase and GSgnnLinkPredictionModelInterface.
- topk_model_to_save: int
The number of top-performing models to save, ranked by evaluation results. Default: 1.
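To make the idea behind topk_model_to_save concrete, here is a minimal pure-Python sketch of keeping only the K best checkpoints by validation score. It is not GraphStorm's implementation: the TopKTracker class, its method names, and the scores are hypothetical, and it assumes that a higher score is better.

```python
import heapq


class TopKTracker:
    """Hypothetical helper: remembers the K best (score, step) pairs seen so far."""

    def __init__(self, k=1):
        self.k = k
        self._heap = []  # min-heap, so the worst retained score sits at index 0

    def update(self, score, step):
        """Return True if the checkpoint taken at `step` should be kept."""
        if len(self._heap) < self.k:
            heapq.heappush(self._heap, (score, step))
            return True
        if score > self._heap[0][0]:
            # The new score beats the worst retained one, so replace it.
            heapq.heapreplace(self._heap, (score, step))
            return True
        return False

    def kept_steps(self):
        """Steps whose checkpoints are currently retained, in ascending order."""
        return sorted(step for _, step in self._heap)


# Illustrative validation scores for four evaluation steps (made-up numbers).
tracker = TopKTracker(k=2)
decisions = [tracker.update(score, step)
             for step, score in enumerate([0.30, 0.25, 0.40, 0.20])]
```

With k=2, the last score is worse than both retained scores, so only that checkpoint is rejected.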
- fit(train_loader, num_epochs, val_loader=None, test_loader=None, use_mini_batch_infer=True, save_model_path=None, save_model_frequency=-1, save_perf_results_path=None, edge_mask_for_gnn_embeddings='train_mask', freeze_input_layer_epochs=0, max_grad_norm=None, grad_norm_type=2.0)
Fit function for link prediction.
This function performs the training for the given link prediction model. It iterates over the training batches provided by train_loader to compute the loss, and then performs the backward steps using the trainer's own optimizer.
If an evaluator and a validation dataloader are added to this trainer, the trainer performs model evaluation during training in three cases:
At the end of each epoch.
At the evaluation frequency (number of iterations) defined in the evaluator.
Before saving a model checkpoint.
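The three cases above can be sketched as a small pure-Python simulation that lists the iterations at which evaluation would run. This is an illustration under assumptions, not the trainer's code: the function name, the eval_frequency argument (standing in for the frequency defined in the evaluator), and the choice not to merge coinciding triggers are all hypothetical.

```python
def evaluation_points(num_epochs, iters_per_epoch, eval_frequency,
                      save_model_frequency=-1):
    """Return (total_iteration, reason) pairs at which evaluation would trigger."""
    events = []
    total_iters = 0
    for _ in range(num_epochs):
        for _ in range(iters_per_epoch):
            total_iters += 1
            # Case 2: the evaluator's frequency, counted in iterations.
            if eval_frequency > 0 and total_iters % eval_frequency == 0:
                events.append((total_iters, "eval_frequency"))
            # Case 3: evaluate right before a checkpoint is written.
            if save_model_frequency > 0 and total_iters % save_model_frequency == 0:
                events.append((total_iters, "before_checkpoint"))
        # Case 1: the end of every epoch.
        events.append((total_iters, "end_of_epoch"))
    return events


events = evaluation_points(num_epochs=2, iters_per_epoch=3, eval_frequency=2)
```

In this run, iteration 6 appears twice because the frequency trigger and the end of the second epoch coincide. A real trainer may merge such triggers into a single evaluation.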
Parameters
- train_loader: GSgnnLinkPredictionDataLoader
Link prediction dataloader for mini-batch sampling the training set.
- num_epochs: int
The maximum number of epochs used to train the model.
- val_loader: GSgnnLinkPredictionDataLoader
Link prediction dataloader for mini-batch sampling the validation set. Default: None.
- test_loader: GSgnnLinkPredictionDataLoader
Link prediction dataloader for mini-batch sampling the test set. Default: None.
- use_mini_batch_infer: bool
Whether to use mini-batch for inference. Default: True.
- save_model_path: str
The path where trained model checkpoints are saved. If None, model checkpoints are not saved. Default: None.
- save_model_frequency: int
The number of training iterations between model checkpoints. Default: -1, meaning the model is saved only after each epoch.
- save_perf_results_path: str
The path of the file where the performance results are saved. Default: None.
- edge_mask_for_gnn_embeddings: str
The mask that indicates the edges used for computing GNN embeddings for model evaluation. By default, it uses the edges in the training graph to compute GNN embeddings for evaluation. Default: 'train_mask'.
- freeze_input_layer_epochs: int
The number of epochs to freeze the input layer from updating trainable parameters. This is commonly used when the input layer contains language models. Default: 0.
- max_grad_norm: float
A value used to clip the gradient, which can enhance training stability. More explanation of this argument can be found in torch.nn.utils.clip_grad_norm_. Default: None.
- grad_norm_type: float
Norm type for the gradient clip. More explanation of this argument can be found in torch.nn.utils.clip_grad_norm_. Default: 2.0.
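As an illustration of how max_grad_norm and grad_norm_type interact, the following pure-Python sketch applies norm-based clipping to a flat list of gradient values. It follows the general scheme described for torch.nn.utils.clip_grad_norm_ (compute the overall norm, then scale all gradients by the same factor when the norm exceeds the limit), but it is a simplified stand-in rather than the PyTorch or GraphStorm code. The function name and the epsilon value are assumptions.

```python
def clip_gradients(grads, max_grad_norm, grad_norm_type=2.0):
    """Scale `grads` (a flat list of floats) so their norm is at most max_grad_norm.

    Returns the scaled gradients and the norm measured before clipping.
    """
    if grad_norm_type == float("inf"):
        total_norm = max(abs(g) for g in grads)
    else:
        total_norm = sum(abs(g) ** grad_norm_type for g in grads) ** (1.0 / grad_norm_type)
    # Never scale up: the coefficient is capped at 1.0.
    # The small epsilon (an assumption here) guards against division by zero.
    clip_coef = min(1.0, max_grad_norm / (total_norm + 1e-6))
    return [g * clip_coef for g in grads], total_norm


# Gradients with L2 norm 5.0 are scaled down to (just under) norm 1.0.
clipped, norm_before = clip_gradients([3.0, 4.0], max_grad_norm=1.0)

# Gradients already within the limit are returned unchanged.
unchanged, _ = clip_gradients([0.3, 0.4], max_grad_norm=1.0)
```

Because every gradient is multiplied by the same coefficient, clipping changes the size of the update but not its direction.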
- eval(model, data, val_loader, test_loader, total_steps, edge_mask_for_gnn_embeddings, use_mini_batch_infer=False)
Evaluates the model on the validation set and, if provided, the test set.
Parameters
- model: GSgnnLinkPredictionModelBase
The GNN model for link prediction, which could be a model class inherited from GSgnnLinkPredictionModelBase, or a model class that inherits from both GSgnnModelBase and GSgnnLinkPredictionModelInterface.
- data: GSgnnData
The GSgnnData associated with the dataloaders.
- val_loader: GSgnnLinkPredictionDataLoader
Link prediction dataloader for mini-batch sampling the validation set. Default: None.
- test_loader: GSgnnLinkPredictionDataLoader
Link prediction dataloader for mini-batch sampling the test set. Default: None.
- total_steps: int
The total number of iterations.
- edge_mask_for_gnn_embeddings: str
The mask that indicates the edges used for computing GNN embeddings for model evaluation.
- use_mini_batch_infer: bool
Whether to use mini-batch for inference. Default: False.
Returns
- val_score: dict
Validation scores of different metrics in the format of {metric: val_score}.