GSgnnNodePredictionTrainer
- class graphstorm.trainer.GSgnnNodePredictionTrainer(model, topk_model_to_save=1)
Bases: GSgnnTrainer
Trainer for node prediction tasks.
GSgnnNodePredictionTrainer is used to train models for node prediction tasks, such as node classification and node regression. GSgnnNodePredictionTrainer defines two main functions:
- fit(): trains the model that was provided to this trainer when the object was initialized.
- eval(): evaluates the provided model on the validation and test sets.
Example
```python
from graphstorm.dataloading import GSgnnNodeDataLoader
from graphstorm.dataset import GSgnnData
from graphstorm.model import GSgnnNodeModel
from graphstorm.trainer import GSgnnNodePredictionTrainer

# Load the graph data (constructor arguments elided).
np_data = GSgnnData("...")

# Training node IDs of node type 'ntype1', sampled with a one-layer fanout of 10.
target_idx = np_data.get_node_train_set('ntype1')
train_loader = GSgnnNodeDataLoader(
    np_data, target_idx, fanout=[10], batch_size=1024,
    label_field="label", node_feats="feat", train_task=True)

model = GSgnnNodeModel(alpha_l2norm=0.0)
trainer = GSgnnNodePredictionTrainer(model)

trainer.fit(train_loader, num_epochs=2)
```
Parameters
- model: GSgnnNodeModelBase
The GNN model for node prediction. It can be a model class that inherits from GSgnnNodeModelBase, or a model class that inherits from both GSgnnModelBase and GSgnnNodeModelInterface.
- topk_model_to_save: int
The number of best-performing models (top K), ranked by evaluation results, to save. Default: 1.
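The top-K saving policy behind topk_model_to_save can be pictured as a bounded min-heap of validation scores. The sketch below is a generic illustration, not GraphStorm's implementation: the class TopKCheckpointTracker and its methods are hypothetical, and it assumes a higher score is better (negate metrics such as RMSE, where lower is better).

```python
import heapq
from typing import List, Optional, Tuple


class TopKCheckpointTracker:
    """Hypothetical helper that keeps the K checkpoints with the highest validation scores."""

    def __init__(self, k: int = 1) -> None:
        if k < 1:
            raise ValueError("k must be at least 1")
        self.k = k
        # Min-heap of (score, checkpoint_name); the worst kept checkpoint sits at index 0.
        self._heap: List[Tuple[float, str]] = []

    def update(self, score: float, checkpoint: str) -> Tuple[bool, Optional[str]]:
        """Record a new evaluation result.

        Returns (keep_new_checkpoint, checkpoint_to_delete).
        """
        if len(self._heap) < self.k:
            # Fewer than K checkpoints so far: always keep the new one.
            heapq.heappush(self._heap, (score, checkpoint))
            return True, None
        worst_score, worst_name = self._heap[0]
        if score > worst_score:
            # The new checkpoint displaces the current worst one.
            heapq.heapreplace(self._heap, (score, checkpoint))
            return True, worst_name
        # Not better than any kept checkpoint: do not save it.
        return False, None

    def kept(self) -> List[str]:
        """Names of the retained checkpoints, best first."""
        return [name for _, name in sorted(self._heap, reverse=True)]


tracker = TopKCheckpointTracker(k=2)
print(tracker.update(0.71, "epoch-0"))  # (True, None)
print(tracker.update(0.75, "epoch-1"))  # (True, None)
print(tracker.update(0.73, "epoch-2"))  # (True, 'epoch-0')
print(tracker.update(0.60, "epoch-3"))  # (False, None)
print(tracker.kept())                   # ['epoch-1', 'epoch-2']
```

Keeping the worst retained score at the heap root makes each update O(log K), and the returned name tells the caller which older checkpoint can be deleted.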
- fit(train_loader, num_epochs, val_loader=None, test_loader=None, use_mini_batch_infer=True, save_model_path=None, save_model_frequency=-1, save_perf_results_path=None, freeze_input_layer_epochs=0, max_grad_norm=None, grad_norm_type=2.0)
Fit function for node prediction training.
This function trains the given node prediction model. It iterates over the training batches provided by train_loader, computes the loss, and then performs the backward step using the trainer's own optimizer.
If an evaluator and a validation dataloader are added to this trainer, the trainer evaluates the model during training in three cases:
- At the end of each epoch.
- At the evaluation frequency (number of iterations) defined in the evaluator.
- Before saving a model checkpoint.
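The three evaluation cases above can be condensed into a single predicate. This is a hypothetical helper for illustration only (should_evaluate and its arguments are not part of the GraphStorm API); it assumes 1-based iteration counts and that a non-positive eval_frequency disables periodic evaluation.

```python
def should_evaluate(
    step: int,
    eval_frequency: int,
    *,
    end_of_epoch: bool = False,
    saving_checkpoint: bool = False,
    has_evaluator: bool = True,
    has_val_loader: bool = True,
) -> bool:
    """Hypothetical predicate for the three evaluation cases (step is 1-based)."""
    # Without both an evaluator and a validation loader, no evaluation happens.
    if not (has_evaluator and has_val_loader):
        return False
    # Case 1: the end of each epoch.
    if end_of_epoch:
        return True
    # Case 2: every `eval_frequency` iterations; a non-positive value disables this case.
    if eval_frequency > 0 and step > 0 and step % eval_frequency == 0:
        return True
    # Case 3: right before a model checkpoint is saved.
    return saving_checkpoint


# With eval_frequency=4, periodic evaluation fires at iterations 4 and 8 of a 10-iteration run.
print([s for s in range(1, 11) if should_evaluate(s, eval_frequency=4)])  # [4, 8]
```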
Parameters
- train_loader: GSgnnNodeDataLoader
Node dataloader for mini-batch sampling the training set.
- num_epochs: int
The maximum number of epochs used to train the model.
- val_loader: GSgnnNodeDataLoader
Node dataloader for mini-batch sampling the validation set. Default: None.
- test_loader: GSgnnNodeDataLoader
Node dataloader for mini-batch sampling the test set. Default: None.
- use_mini_batch_infer: bool
Whether to use mini-batch for inference. Default: True.
- save_model_path: str
The path where trained model checkpoints are saved. If None, model checkpoints are not saved. Default: None.
- save_model_frequency: int
The number of training iterations between model checkpoints. Default: -1, meaning the model is saved only at the end of each epoch.
- save_perf_results_path: str
The path of the file where the performance results are saved. Default: None.
- freeze_input_layer_epochs: int
The number of initial epochs during which the input layer's trainable parameters are frozen (not updated). This is commonly used when the input layer contains language models. Default: 0.
- max_grad_norm: float
The maximum gradient norm used for gradient clipping, which can improve training stability. See torch.nn.utils.clip_grad_norm_ for details. Default: None.
- grad_norm_type: float
The norm type used for gradient clipping. See torch.nn.utils.clip_grad_norm_ for details. Default: 2.0.
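To make max_grad_norm and grad_norm_type concrete, here is a pure-Python sketch of global-norm gradient clipping that follows the documented behavior of torch.nn.utils.clip_grad_norm_: all gradients are treated as one vector, and they are scaled by roughly max_norm / total_norm when that factor is below 1. It works on plain lists instead of tensors and the function name is hypothetical; in real training code, call torch.nn.utils.clip_grad_norm_ directly.

```python
import math
from typing import List


def clip_grad_norm(grads: List[List[float]], max_norm: float, norm_type: float = 2.0) -> float:
    """Scale gradients in place so their global p-norm is at most max_norm.

    Returns the total norm measured before clipping.
    """
    flat = [g for grad in grads for g in grad]
    if math.isinf(norm_type):
        # Infinity norm: the largest absolute gradient entry.
        total_norm = max((abs(g) for g in flat), default=0.0)
    else:
        # p-norm over every gradient entry of every parameter.
        total_norm = sum(abs(g) ** norm_type for g in flat) ** (1.0 / norm_type)
    # A small epsilon avoids division by zero when all gradients are zero.
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= clip_coef
    return total_norm


grads = [[3.0], [4.0]]           # two "parameters" with one gradient entry each
print(clip_grad_norm(grads, max_norm=1.0))  # 5.0 (the L2 norm before clipping)
print(grads)                     # each entry scaled by about 0.2
```

Because clipping uses a single global norm, the direction of the combined gradient is preserved; only its length is reduced.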
- eval(model, val_loader, test_loader, use_mini_batch_infer, total_steps, return_proba=True)
Evaluate the model on the validation set, and also on the test set if one is provided.
Parameters
- model: GSgnnNodeModelBase
The GNN model for node prediction. It can be a model class that inherits from GSgnnNodeModelBase, or a model class that inherits from both GSgnnModelBase and GSgnnNodeModelInterface.
- val_loader: GSgnnNodeDataLoader
Node dataloader for mini-batch sampling the validation set. Default: None.
- test_loader: GSgnnNodeDataLoader
Node dataloader for mini-batch sampling the test set. Default: None.
- use_mini_batch_infer: bool
Whether to use mini-batch for inference. Default: True.
- total_steps: int
The total number of iterations.
- return_proba: bool
Whether to return the full per-class prediction results (True) or only the argmax class predictions (False) for classification tasks. Default: True.
Returns
- val_score: dict
Validation scores of different metrics, in the format {metric: val_score}.
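The difference that return_proba makes for classification can be illustrated on raw per-node logits. This is a hypothetical helper, not GraphStorm code: it assumes the model emits unnormalized logits and that the "full" prediction results are softmax probabilities, while the alternative is the index of the highest-scoring class.

```python
import math
from typing import List, Union


def format_predictions(
    logits: List[List[float]], return_proba: bool = True
) -> Union[List[List[float]], List[int]]:
    """Hypothetical helper: turn per-node class logits into probabilities or argmax labels."""
    if return_proba:
        probs = []
        for row in logits:
            m = max(row)  # subtract the row maximum for numerical stability
            exps = [math.exp(x - m) for x in row]
            total = sum(exps)
            probs.append([e / total for e in exps])
        return probs
    # Argmax: index of the largest logit per node (the first index wins on ties).
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


logits = [[2.0, 0.5, 0.1], [0.0, 0.0, 3.0]]
print(format_predictions(logits, return_proba=False))  # [0, 2]
print(format_predictions(logits))  # two rows of class probabilities, each summing to 1
```

Returning probabilities keeps the information needed for ranking or thresholding metrics, whereas argmax labels are enough for metrics such as accuracy.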