GSgnnModelBase
- class graphstorm.model.GSgnnModelBase(*args, **kwargs)
Bases: Module
GraphStorm GNN model base class.
Any GNN model trained by GraphStorm should inherit from this class. It declares abstract methods that subclasses must implement, and it also provides some utility methods.
- abstract restore_dense_model(restore_model_path, model_layer_to_load=None)
Restore dense models, e.g., GNN Encoders, Decoders, etc.
All model parameters except learnable node embeddings, i.e., dgl.distributed.DistEmbedding, are restored by this function. This function goes through all the model layers and loads the corresponding parameters from restore_model_path.
In some cases, users can choose which model layer(s) to load by setting model_layer_to_load. model_layer_to_load indicates the names of the model layer(s) that should be restored.
Example Implementation:
The code below provides an exemplary implementation of this abstract method.
To restore model parameters for a model with all three layers of a GraphStorm GNN model, including an input layer, a GNN layer and a decoder layer:
```python
import os
import torch as th

# Suppose we are going to load all three layers.
input_encoder = self.input_encoder
gnn_model = self.gnn_model
decoder = self.decoder

checkpoint = th.load(os.path.join(restore_model_path, 'model.bin'),
                     map_location='cpu',
                     weights_only=True)

assert 'gnn' in checkpoint
assert 'input' in checkpoint
assert 'decoder' in checkpoint

input_encoder.load_state_dict(checkpoint['input'], strict=False)
gnn_model.load_state_dict(checkpoint['gnn'])
decoder.load_state_dict(checkpoint['decoder'])
```
Parameters
- restore_model_path: str
The path where the model was stored.
- model_layer_to_load: list of str
List of model layers to load. This argument indicates which model layer(s) will be restored from the model checkpoint. Default: None.
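To illustrate how model_layer_to_load might narrow a restore, here is a minimal sketch in plain Python. It assumes the checkpoint is a dict keyed by layer name ('input', 'gnn' and 'decoder', as in the example implementation). `select_layers` is a hypothetical helper, not part of GraphStorm, and plain dicts stand in for real state dicts.

```python
def select_layers(checkpoint, model_layer_to_load=None):
    """Return only the checkpoint entries named in model_layer_to_load.

    Hypothetical helper: `checkpoint` is assumed to be a dict keyed by
    layer name. None means "load every layer".
    """
    if model_layer_to_load is None:
        return dict(checkpoint)
    missing = [name for name in model_layer_to_load if name not in checkpoint]
    if missing:
        raise KeyError(f"layers not found in checkpoint: {missing}")
    return {name: checkpoint[name] for name in model_layer_to_load}


# Plain dicts stand in for the state dicts a real checkpoint would hold.
checkpoint = {"input": {"w": 1}, "gnn": {"w": 2}, "decoder": {"w": 3}}
gnn_only = select_layers(checkpoint, ["gnn"])
everything = select_layers(checkpoint)
```

Failing early on a missing layer name, rather than silently skipping it, is one possible design choice; an implementation could equally decide to ignore absent layers.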
- abstract restore_sparse_model(restore_model_path)
Restore sparse models, e.g., learnable node embeddings.
Learnable node embeddings are restored by this function.
Example Implementation:
The code below provides an exemplary implementation of this abstract method.
To load sparse model parameters for a node_input_encoder:
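```python
import os
from graphstorm.model.utils import load_sparse_emb

# sparse_embeds is assumed to map each node type to its learnable embedding.
for ntype, sparse_emb in sparse_embeds.items():
    load_sparse_emb(sparse_emb, os.path.join(restore_model_path, ntype))
```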
Parameters
- restore_model_path: str
The path where the model was stored.
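The example implementation reads each node type's embeddings from a subfolder of restore_model_path named after the node type. A minimal sketch of checking that layout before loading could look like the following. `sparse_emb_dirs` is a hypothetical helper, and the one-folder-per-node-type layout is an assumption taken from the example, not a documented guarantee.

```python
import os
import tempfile


def sparse_emb_dirs(restore_model_path, ntypes):
    """Map each node type to its assumed embedding folder.

    Raises FileNotFoundError if any expected folder is absent.
    """
    paths = {ntype: os.path.join(restore_model_path, ntype) for ntype in ntypes}
    missing = [p for p in paths.values() if not os.path.isdir(p)]
    if missing:
        raise FileNotFoundError(f"missing sparse embedding folders: {missing}")
    return paths


# Build a throwaway checkpoint folder with embeddings for "user" only.
with tempfile.TemporaryDirectory() as ckpt:
    os.makedirs(os.path.join(ckpt, "user"))
    found_names = sorted(sparse_emb_dirs(ckpt, ["user"]))
    try:
        sparse_emb_dirs(ckpt, ["user", "item"])
        item_missing = False
    except FileNotFoundError:
        item_missing = True
```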
- abstract save_dense_model(model_path)
Save dense models, e.g., GNN Encoders, Decoders, etc.
All model parameters except learnable node embeddings, i.e., dgl.distributed.DistEmbedding, are saved by this function. This function should go through all model layers and save the corresponding parameters under model_path.
Example Implementation:
The code below provides an exemplary implementation of this abstract method.
```python
import os
import torch as th

# This function is only called by rank 0.
input_encoder = self.input_encoder
gnn_model = self.gnn_model
decoder = self.decoder

model_states = {}
model_states['gnn'] = gnn_model.state_dict()
model_states['input'] = input_encoder.state_dict()
model_states['decoder'] = decoder.state_dict()

os.makedirs(model_path, exist_ok=True)
# Mode 767 means rwxrw-rwx: owner rwx, group rw-, others rwx.
os.chmod(model_path, 0o767)
th.save(model_states, os.path.join(model_path, 'model.bin'))
```
Parameters
- model_path: str
The path where all model parameters and optimizer states will be saved.
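The octal mode 0o767 passed to os.chmod in the example can be decoded with the standard library. This short check only illustrates what the mode means; it is not part of GraphStorm.

```python
import stat

mode = 0o767
# stat.filemode expects the file-type bits too, so mark the path as a directory.
readable = stat.filemode(stat.S_IFDIR | mode)
print(readable)  # drwxrw-rwx
```

The owner and others get read, write and execute permission, while the group gets read and write only.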
- abstract save_sparse_model(model_path)
Save sparse models, e.g., learnable node embeddings.
Learnable node embeddings are saved by this function. Saving learnable node embeddings only works when 1) the training task runs on a single machine, or 2) the training task runs in a distributed environment with a shared file system.
Example Implementation:
The code below provides an exemplary implementation of this abstract method.
The implementation of save_sparse_model usually includes two steps:
Step 1: Create a path to save the learnable node embeddings.
```python
from graphstorm.model.utils import create_sparse_emb_path

for ntype, sparse_emb in sparse_embeds.items():
    create_sparse_emb_path(model_path, ntype)
# Make sure rank 0 creates the folder and changes its permissions first.
```
Step 2: Save learnable node embeddings.
```python
from graphstorm.model.utils import save_sparse_emb

for ntype, sparse_emb in sparse_embeds.items():
    save_sparse_emb(model_path, sparse_emb, ntype)
```
Parameters
- model_path: str
The path where all model sparse parameters will be saved.
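The reason for the two-step ordering (create the folder first, then save) can be sketched with standard-library threads standing in for distributed workers. Everything here is an illustrative assumption: the worker count, the `part_{rank}.txt` file naming, and the use of `threading.Barrier` in place of a real distributed barrier. GraphStorm's own helpers are not used.

```python
import os
import tempfile
import threading

WORLD_SIZE = 3  # hypothetical number of workers
barrier = threading.Barrier(WORLD_SIZE)


def worker(rank, model_path, ntype):
    emb_dir = os.path.join(model_path, ntype)
    if rank == 0:
        # Step 1: only rank 0 creates the folder on the shared file system.
        os.makedirs(emb_dir, exist_ok=True)
    # All workers wait here, so nobody writes before the folder exists.
    barrier.wait()
    # Step 2: each worker writes its own shard (hypothetical file naming).
    with open(os.path.join(emb_dir, f"part_{rank}.txt"), "w") as f:
        f.write(f"embeddings from rank {rank}\n")


with tempfile.TemporaryDirectory() as model_path:
    threads = [threading.Thread(target=worker, args=(r, model_path, "user"))
               for r in range(WORLD_SIZE)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    written = sorted(os.listdir(os.path.join(model_path, "user")))
```

Without the barrier, a worker other than rank 0 could try to write into a folder that does not exist yet.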
- restore_model(restore_model_path, model_layer_to_load=None)
Restore saved checkpoints of a GNN model.
Users who want to override this method should load the parameters of the GNN model. This method does not need to load the optimizer state.
Examples
Load a model from “/tmp/checkpoints”.
```python
# CustomGSgnnModel is a child class of GSgnnModelBase
model = CustomGSgnnModel()

# Restore model parameters from "/tmp/checkpoints"
model.restore_model("/tmp/checkpoints")
```
Parameters
- restore_model_path: str
The path where the model was stored.
- model_layer_to_load: list of str
List of model layers to load. Supported layers include: “embed”, “gnn”, “decoder”. Default: None.
- save_model(model_path)
Save a trained model.
Saving a model requires saving both the dense parameters and the sparse parameters.
Examples
Save a model into “/tmp/checkpoints”.
```python
# CustomGSgnnModel is a child class of GSgnnModelBase
model = CustomGSgnnModel()

# Model parameters will be saved into "/tmp/checkpoints"
model.save_model("/tmp/checkpoints")
```
Parameters
- model_path: str
The path where all model parameters and optimizer states will be saved.
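How a concrete save_model could build on the two abstract save hooks can be sketched with a toy base class. `ModelBaseSketch` and `RecordingModel` are hypothetical stand-ins written with the standard `abc` module. The sketch assumes save_model simply calls the dense hook and then the sparse hook, which the description suggests but does not spell out.

```python
from abc import ABC, abstractmethod


class ModelBaseSketch(ABC):
    """Toy stand-in for a base class with abstract save hooks."""

    @abstractmethod
    def save_dense_model(self, model_path):
        """Save everything except learnable node embeddings."""

    @abstractmethod
    def save_sparse_model(self, model_path):
        """Save learnable node embeddings."""

    def save_model(self, model_path):
        # Assumed behavior: delegate to both hooks, dense first.
        self.save_dense_model(model_path)
        self.save_sparse_model(model_path)


class RecordingModel(ModelBaseSketch):
    """Subclass that records calls instead of writing files."""

    def __init__(self):
        self.calls = []

    def save_dense_model(self, model_path):
        self.calls.append(("dense", model_path))

    def save_sparse_model(self, model_path):
        self.calls.append(("sparse", model_path))


model = RecordingModel()
model.save_model("/tmp/checkpoints")
```

Because the hooks are abstract, instantiating `ModelBaseSketch` directly raises a TypeError; only subclasses that define both hooks can be created.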
- abstract create_optimizer()
Create the optimizer that optimizes the model.
Users who want to customize a model should define an optimizer for it. With this method, users can choose their own optimization algorithm, the learning rate, and any other hyperparameters.
A model may require multiple optimizers. For example, we should define an optimizer for sparse embeddings and an optimizer for the dense parameters of a GNN model. In this case, a user should use a GSOptimizer to combine these optimizers.
Example:
Case 1: if there is only one optimizer:
```python
def create_optimizer(self):
    # define torch.optim.Optimizer
    return optimizer
```
Case 2: if there are both dense and sparse optimizers:
```python
def create_optimizer(self):
    dense = [dense_opt]    # define torch.optim.Optimizer
    sparse = [sparse_opt]  # define dgl sparse Optimizer
    optimizer = GSOptimizer(dense_opts=dense,
                            lm_opts=None,
                            sparse_opts=sparse)
    return optimizer
```
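The idea of combining several optimizers behind one object can be sketched without any deep learning library. `CombinedOptimizer` and `FakeOpt` are toy, hypothetical classes that only illustrate the pattern of fanning out zero_grad and step calls. They are not GraphStorm's GSOptimizer and make no claim about how it is implemented.

```python
class CombinedOptimizer:
    """Toy wrapper that drives several optimizers as one (hypothetical)."""

    def __init__(self, dense_opts=None, sparse_opts=None):
        self.opts = list(dense_opts or []) + list(sparse_opts or [])

    def zero_grad(self):
        for opt in self.opts:
            opt.zero_grad()

    def step(self):
        for opt in self.opts:
            opt.step()


class FakeOpt:
    """Records calls; stands in for a real dense or sparse optimizer."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def zero_grad(self):
        self.log.append((self.name, "zero_grad"))

    def step(self):
        self.log.append((self.name, "step"))


log = []
optimizer = CombinedOptimizer(dense_opts=[FakeOpt("dense", log)],
                              sparse_opts=[FakeOpt("sparse", log)])
optimizer.zero_grad()
optimizer.step()
```

A training loop can then treat the combined object like a single optimizer, regardless of how many underlying optimizers the model needs.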
- prepare_input_encoder(train_data)
Prepare the input layer for training or inference.
The input layer can pre-compute node features in the preparation step if needed, e.g., pre-compute all BERT embeddings.
Default: do nothing
Parameters
- train_data: GSgnnData
Graph data
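The pre-computation idea can be sketched as a one-off pass that caches expensive features so later lookups are cheap. `CachingInputEncoder` and `slow_encode` are hypothetical. The "embedding" here is just the text length, a placeholder for something costly such as a language model forward pass.

```python
class CachingInputEncoder:
    """Toy input layer that pre-computes features once (hypothetical)."""

    def __init__(self, encode_fn):
        self.encode_fn = encode_fn
        self.cache = {}

    def prepare(self, raw_features):
        # One-off pass: run the expensive encoder once per node.
        self.cache = {nid: self.encode_fn(text)
                      for nid, text in raw_features.items()}

    def forward(self, node_ids):
        # Later lookups reuse the cached values.
        return [self.cache[nid] for nid in node_ids]


calls = []


def slow_encode(text):
    calls.append(text)  # track how often the expensive step runs
    return len(text)    # placeholder "embedding"


encoder = CachingInputEncoder(slow_encode)
encoder.prepare({0: "graph", 1: "storms"})
first = encoder.forward([0, 1])
second = encoder.forward([1, 0])
```

After `prepare`, repeated `forward` calls never invoke the expensive encoder again.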
- property device
Return the device where the model runs.
This implementation assumes that all model parameters are on the same device.