GSgnnNodePredictionTrainer

class graphstorm.trainer.GSgnnNodePredictionTrainer(model, topk_model_to_save=1)

Bases: GSgnnTrainer

Trainer for node prediction tasks.

GSgnnNodePredictionTrainer is used to train models for node prediction tasks, such as node classification and node regression. GSgnnNodePredictionTrainer defines two main functions:

  • fit(): trains the model provided to this trainer at initialization; and

  • eval(): evaluates the provided model on the validation and test datasets.

Example

from graphstorm.dataloading import GSgnnNodeDataLoader
from graphstorm.dataset import GSgnnData
from graphstorm.model import GSgnnNodeModel
from graphstorm.trainer import GSgnnNodePredictionTrainer

np_data = GSgnnData("...")
target_idx = np_data.get_node_train_set('ntype1')
train_loader = GSgnnNodeDataLoader(
    np_data, target_idx, fanout=[10], batch_size=1024,
    label_field="label", node_feats="feat", train_task=True)
model = GSgnnNodeModel(alpha_l2norm=0.0)

trainer = GSgnnNodePredictionTrainer(model)

trainer.fit(train_loader, num_epochs=2)

Parameters

model: GSgnnNodeModelBase

The GNN model for node prediction, which could be a model class inherited from the GSgnnNodeModelBase, or a model class that inherits both the GSgnnModelBase and the GSgnnNodeModelInterface.

topk_model_to_save: int

The number of top-performing models, ranked by evaluation results, to save. Default: 1.
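To make the idea behind topk_model_to_save concrete, the following is a minimal sketch of top-K checkpoint bookkeeping. It is not GraphStorm's implementation: the TopKCheckpointTracker class and its methods are hypothetical names introduced here, and the sketch assumes that a higher validation score is better.

```python
import heapq


class TopKCheckpointTracker:
    """Hypothetical helper (not part of GraphStorm): keeps the K best scores seen."""

    def __init__(self, k=1):
        self.k = k
        # Min-heap of (score, iteration); the worst retained score sits at index 0.
        self._heap = []

    def update(self, score, iteration):
        """Record a validation score; return True if this checkpoint should be kept."""
        if len(self._heap) < self.k:
            heapq.heappush(self._heap, (score, iteration))
            return True
        if score > self._heap[0][0]:
            # The new score beats the worst retained one, so it takes that slot.
            heapq.heapreplace(self._heap, (score, iteration))
            return True
        return False

    def kept_iterations(self):
        """Iterations whose checkpoints are currently retained, in ascending order."""
        return sorted(iteration for _, iteration in self._heap)


tracker = TopKCheckpointTracker(k=2)
scores = [0.70, 0.65, 0.80, 0.60]  # assumed: higher is better
decisions = [tracker.update(score, it) for it, score in enumerate(scores)]
```

With K set to 2, the last score (0.60) is lower than both retained scores, so that checkpoint would not be kept.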

fit(train_loader, num_epochs, val_loader=None, test_loader=None, use_mini_batch_infer=True, save_model_path=None, save_model_frequency=-1, save_perf_results_path=None, freeze_input_layer_epochs=0, max_grad_norm=None, grad_norm_type=2.0)

Fit function for node prediction training.

This function performs the training for the given node prediction model. It iterates over the training batches provided by the train_loader to compute the loss, and then performs the backward steps using the trainer’s own optimizer.

If an evaluator has been added to this trainer and a validation dataloader is provided, the trainer evaluates the model during training in three cases:

  • At the end of each epoch.

  • At the evaluation frequency (number of iterations) defined in the evaluator.

  • Before saving a model checkpoint.
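The three cases above can be sketched as a small scheduling function. This is only an illustration under stated assumptions: evaluation_reasons and the eval_frequency argument are hypothetical names, a non-positive frequency is assumed to disable that trigger, and the real trainer may combine or deduplicate evaluations that coincide on the same iteration.

```python
def evaluation_reasons(total_steps, is_epoch_end, eval_frequency, save_model_frequency):
    """Return the reasons (possibly none) to evaluate at this point in training.

    Hypothetical sketch; a non-positive frequency disables that trigger.
    """
    reasons = []
    if eval_frequency > 0 and total_steps % eval_frequency == 0:
        reasons.append("eval_frequency")
    if save_model_frequency > 0 and total_steps % save_model_frequency == 0:
        reasons.append("before_checkpoint")
    if is_epoch_end:
        reasons.append("epoch_end")
    return reasons


# Simulate one epoch of six iterations.
schedule = {}
for step in range(1, 7):
    reasons = evaluation_reasons(
        step, is_epoch_end=(step == 6), eval_frequency=2, save_model_frequency=3)
    if reasons:
        schedule[step] = reasons
```

In this simulation, iteration 6 matches all three triggers at once, while iterations 1 and 5 match none.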

Parameters

train_loader: GSgnnNodeDataLoader

Node dataloader for mini-batch sampling the training set.

num_epochs: int

The maximum number of epochs to train the model.

val_loader: GSgnnNodeDataLoader

Node dataloader for mini-batch sampling the validation set. Default: None.

test_loader: GSgnnNodeDataLoader

Node dataloader for mini-batch sampling the test set. Default: None.

use_mini_batch_infer: bool

Whether to use mini-batch for inference. Default: True.

save_model_path: str

The path where trained model checkpoints are saved. If None, model checkpoints are not saved. Default: None.

save_model_frequency: int

The number of iterations to train the model before saving a model checkpoint. Default: -1, meaning the model is saved only at the end of each epoch.

save_perf_results_path: str

The path of the file where the performance results are saved. Default: None.

freeze_input_layer_epochs: int

The number of epochs for which the input layer is frozen, i.e., its trainable parameters are not updated. This is commonly used when the input layer contains language models. Default: 0.

max_grad_norm: float

A value used to clip the gradient, which can enhance training stability. More explanation of this argument can be found in torch.nn.utils.clip_grad_norm_. Default: None.

grad_norm_type: float

Norm type for the gradient clip. More explanation of this argument can be found in torch.nn.utils.clip_grad_norm_. Default: 2.0.
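To illustrate what max_grad_norm and grad_norm_type control, here is a dependency-free sketch of gradient clipping by global norm. It follows the general scheme of torch.nn.utils.clip_grad_norm_ (compute one norm over all gradients, then scale them down if that norm exceeds the limit), but it operates on plain Python lists rather than tensors. The function name clip_by_global_norm and the small epsilon are choices made for this sketch.

```python
import math


def clip_by_global_norm(grads, max_norm, norm_type=2.0):
    """Scale all gradients so that their combined p-norm is at most max_norm.

    grads is a list of per-parameter gradient lists (illustrative stand-in
    for tensors). Returns the clipped gradients and the norm before clipping.
    """
    flat = [g for grad in grads for g in grad]
    if math.isinf(norm_type):
        total_norm = max(abs(g) for g in flat)
    else:
        total_norm = sum(abs(g) ** norm_type for g in flat) ** (1.0 / norm_type)
    # Never scale up: the coefficient is capped at 1.0.
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    clipped = [[g * scale for g in grad] for grad in grads]
    return clipped, total_norm


grads = [[3.0, 0.0], [0.0, 4.0]]  # global L2 norm is 5.0
clipped, total_norm = clip_by_global_norm(grads, max_norm=1.0)

small = [[0.3, 0.4]]  # global L2 norm is about 0.5, below the limit
unchanged, _ = clip_by_global_norm(small, max_norm=1.0)
```

Gradients whose global norm is already within max_norm are left as they are; only larger gradients are rescaled, which preserves their direction.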

eval(model, val_loader, test_loader, use_mini_batch_infer, total_steps, return_proba=True)

Evaluate the model on the validation set and, if provided, the test set.

Parameters

model: GSgnnNodeModelBase

The GNN model for node prediction, which could be a model class inherited from the GSgnnNodeModelBase, or a model class that inherits both the GSgnnModelBase and the GSgnnNodeModelInterface.

val_loader: GSgnnNodeDataLoader

Node dataloader for mini-batch sampling the validation set.

test_loader: GSgnnNodeDataLoader

Node dataloader for mini-batch sampling the test set.

use_mini_batch_infer: bool

Whether to use mini-batch for inference.

total_steps: int

The total number of iterations.

return_proba: bool

For classification tasks, whether to return the full prediction results (True) or only the argmax results (False). Default: True.
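The difference between the two return_proba settings can be shown on plain lists. This sketch reflects one reading of the flag based on its name and description, not GraphStorm's code; format_predictions is a hypothetical helper, and the per-class scores are made-up illustrative values.

```python
def format_predictions(probabilities, return_proba=True):
    """Return per-class scores as-is, or reduce each row to its argmax class index."""
    if return_proba:
        return probabilities
    return [max(range(len(row)), key=row.__getitem__) for row in probabilities]


# Illustrative per-class scores for two nodes and three classes.
probs = [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]]
as_proba = format_predictions(probs, return_proba=True)
as_labels = format_predictions(probs, return_proba=False)
```

Here as_proba keeps one score per class for each node, while as_labels holds a single class index per node.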

Returns

val_score: dict

Validation scores of different metrics in the format of {metric: val_score}.