GSgnnModelBase

class graphstorm.model.GSgnnModelBase(*args, **kwargs)

Bases: Module

GraphStorm GNN model base class.

Any GNN model trained by GraphStorm should inherit from this class. It declares abstract methods that subclasses must implement, and it also provides some utility methods.

abstract restore_dense_model(restore_model_path, model_layer_to_load=None)

Restore dense models, e.g., GNN Encoders, Decoders, etc.

All model parameters except for learnable node embeddings, i.e., dgl.distributed.DistEmbedding, are restored by this function. This function goes through all the model layers and loads the corresponding parameters from restore_model_path.

In some cases, users can choose which model layer(s) to load by setting model_layer_to_load. model_layer_to_load is designed to indicate the names of model layer(s) that should be restored.

Example Implementation:

The code below provides an example implementation of this abstract method.

To restore model parameters for a model with all three layers of a GraphStorm GNN model, including an input layer, a GNN layer and a decoder layer:

import os
import torch as th

# Suppose we are going to load all three layers.
input_encoder = self.input_encoder
gnn_model = self.gnn_model
decoder = self.decoder

# Load the checkpoint onto the CPU from the path passed to this method.
checkpoint = th.load(os.path.join(restore_model_path, 'model.bin'),
                     map_location='cpu',
                     weights_only=True)

assert 'gnn' in checkpoint
assert 'input' in checkpoint
assert 'decoder' in checkpoint

input_encoder.load_state_dict(checkpoint['input'], strict=False)
gnn_model.load_state_dict(checkpoint['gnn'])
decoder.load_state_dict(checkpoint['decoder'])

Parameters

restore_model_path: str

The path where the model was stored.

model_layer_to_load: list of str

List of model layers to load. This argument indicates which model layer(s) are going to be restored from the model checkpoint. Default: None.
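One way an implementation might honor model_layer_to_load is to translate the requested layer names into checkpoint keys before calling load_state_dict. The following is a minimal sketch in plain Python, not GraphStorm code: the helper `select_checkpoint_keys` is hypothetical, the mapping from "embed" to the 'input' key is inferred from the layer names listed for restore_model and the checkpoint keys used in the example above, and treating None as "load every layer" is likewise an assumption.

```python
# Hypothetical mapping from layer names to checkpoint keys (an assumption).
LAYER_TO_CHECKPOINT_KEY = {
    "embed": "input",
    "gnn": "gnn",
    "decoder": "decoder",
}


def select_checkpoint_keys(checkpoint, model_layer_to_load=None):
    """Return the checkpoint keys to restore, in the order requested.

    None is treated as "all known layers" (an assumption of this sketch).
    Unknown layer names raise ValueError; requested layers that are
    missing from the checkpoint raise KeyError.
    """
    if model_layer_to_load is None:
        model_layer_to_load = list(LAYER_TO_CHECKPOINT_KEY)

    keys = []
    for layer in model_layer_to_load:
        if layer not in LAYER_TO_CHECKPOINT_KEY:
            raise ValueError(f"Unknown model layer: {layer!r}")
        key = LAYER_TO_CHECKPOINT_KEY[layer]
        if key not in checkpoint:
            raise KeyError(f"Checkpoint has no entry for layer {layer!r}")
        keys.append(key)
    return keys


# A plain dict stands in for the object returned by th.load().
checkpoint = {"input": {}, "gnn": {}, "decoder": {}}
print(select_checkpoint_keys(checkpoint))                      # all three layers
print(select_checkpoint_keys(checkpoint, ["gnn", "decoder"]))  # skip the input layer
```

Failing early on unknown or missing layers keeps a partially restored model from going unnoticed; a real implementation may prefer to skip missing layers silently instead.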

abstract restore_sparse_model(restore_model_path)

Restore sparse models, e.g., learnable node embeddings.

Learnable node embeddings are restored by this function.

Example Implementation:

The code below provides an example implementation of this abstract method.

To load sparse model parameters for a node_input_encoder:

import os
from graphstorm.model.utils import load_sparse_emb

# Each node type's embeddings are read from its own sub-path.
for ntype, sparse_emb in sparse_embeds.items():
    load_sparse_emb(sparse_emb, os.path.join(restore_model_path, ntype))

Parameters

restore_model_path: str

The path where the model was stored.

abstract save_dense_model(model_path)

Save dense models, e.g., GNN Encoders, Decoders, etc.

All model parameters except for learnable node embeddings, i.e., dgl.distributed.DistEmbedding, are saved by this function. This function should go through all model layers and save the corresponding parameters under model_path.

Example Implementation:

The code below provides an example implementation of this abstract method.

# This function is only called by rank 0
input_encoder = self.input_encoder
gnn_model = self.gnn_model
decoder = self.decoder

model_states = {}
model_states['gnn'] = gnn_model.state_dict()
model_states['input'] = input_encoder.state_dict()
model_states['decoder'] = decoder.state_dict()

os.makedirs(model_path, exist_ok=True)
# mode 0o767 means rwxrw-rwx (owner: rwx, group: rw-, others: rwx):
os.chmod(model_path, 0o767)
th.save(model_states, os.path.join(model_path, 'model.bin'))

Parameters

model_path: str

The path where all model parameters and optimizer states will be saved.
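The octal mode `0o767` used in the example above can be decoded with the standard library alone. The sketch below is an illustration on a POSIX system, not GraphStorm code; the helper name `describe_mode` and the scratch directory are assumptions made for the example.

```python
import os
import stat
import tempfile


def describe_mode(mode):
    """Render the permission bits of a directory the way `ls -l` would."""
    return stat.filemode(stat.S_IFDIR | mode)


print(describe_mode(0o767))  # drwxrw-rwx

# Apply the mode to a scratch directory and read it back.
with tempfile.TemporaryDirectory() as scratch:
    model_path = os.path.join(scratch, "checkpoint")
    os.makedirs(model_path, exist_ok=True)
    os.chmod(model_path, 0o767)
    applied = stat.S_IMODE(os.stat(model_path).st_mode)
    print(oct(applied))
```

Calling os.chmod explicitly matters because the mode passed to os.makedirs is filtered by the process umask, whereas os.chmod sets the requested bits directly.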

abstract save_sparse_model(model_path)

Save sparse models, e.g., learnable node embeddings.

Learnable node embeddings are saved by this function. Saving learnable node embeddings only works when 1) the training task runs on a single machine or 2) the training task runs in a distributed environment with a shared file system.

Example Implementation:

The code below provides an example implementation of this abstract method.

The implementation of save_sparse_model usually includes two steps:

Step 1: Create a path to save the learnable node embeddings.

from graphstorm.model.utils import create_sparse_emb_path

for ntype, sparse_emb in sparse_embeds.items():
    create_sparse_emb_path(model_path, ntype)
# Make sure rank 0 creates the folder and changes its permissions first.

Step 2: Save learnable node embeddings.

from graphstorm.model.utils import save_sparse_emb

for ntype, sparse_emb in sparse_embeds.items():
    save_sparse_emb(model_path, sparse_emb, ntype)

Parameters

model_path: str

The path where all model sparse parameters will be saved.
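The two-step pattern above (create every target path first, then write) can be illustrated without GraphStorm or a distributed setup. In the sketch below, `make_emb_dir`, `write_emb`, and `save_sparse_sketch` are hypothetical stand-ins, and the JSON file named `emb.json` is an assumption for illustration only; it does not reflect how save_sparse_emb actually stores embeddings. The one-sub-directory-per-node-type layout follows the `os.path.join(model_path, ntype)` convention used in the restore example.

```python
import json
import os
import tempfile


def make_emb_dir(model_path, ntype):
    """Hypothetical stand-in for step 1: one sub-directory per node type."""
    emb_path = os.path.join(model_path, ntype)
    os.makedirs(emb_path, exist_ok=True)
    return emb_path


def write_emb(model_path, emb, ntype):
    """Hypothetical stand-in for step 2: write one embedding table to disk."""
    out_file = os.path.join(model_path, ntype, "emb.json")
    with open(out_file, "w", encoding="utf-8") as f:
        json.dump(emb, f)
    return out_file


def save_sparse_sketch(model_path, sparse_embeds):
    # Step 1: create every directory before any write starts.
    for ntype in sparse_embeds:
        make_emb_dir(model_path, ntype)
    # Step 2: write the embeddings into the prepared directories.
    return [write_emb(model_path, emb, ntype)
            for ntype, emb in sparse_embeds.items()]


with tempfile.TemporaryDirectory() as model_path:
    toy_embeds = {"user": [[0.1, 0.2]], "item": [[0.3, 0.4]]}
    written = save_sparse_sketch(model_path, toy_embeds)
    print(sorted(os.listdir(model_path)))
```

Separating the two steps means that, in a multi-process run, all writers can presumably be held back until the directories exist, which is what the rank 0 comment in the example above hints at.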

restore_model(restore_model_path, model_layer_to_load=None)

Restore saved checkpoints of a GNN model.

Users who want to override this method should load the parameters of the GNN model. This method does not need to load the optimizer state.

Examples

Load a model from “/tmp/checkpoints”.

# CustomGSgnnModel is a child class of GSgnnModelBase
model = CustomGSgnnModel()

# Restore model parameters from "/tmp/checkpoints"
model.restore_model("/tmp/checkpoints")

Parameters

restore_model_path: str

The path where the model was stored.

model_layer_to_load: list of str

List of model layers to load. Supported layers include: “embed”, “gnn”, “decoder”.

save_model(model_path)

Save a trained model.

Saving a model requires saving both the dense parameters and the sparse parameters.

Examples

Save a model into “/tmp/checkpoints”.

# CustomGSgnnModel is a child class of GSgnnModelBase
model = CustomGSgnnModel()

# Model parameters will be saved into "/tmp/checkpoints"
model.save_model("/tmp/checkpoints")

Parameters

model_path: str

The path where all model parameters and optimizer states will be saved.
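Because saving a model covers both dense and sparse parameters, one plausible structure is for save_model to delegate to the two abstract save methods. The sketch below illustrates that pattern with a stand-in base class, `SketchModelBase`, built on the standard `abc` module; it is not GraphStorm's class, and whether GSgnnModelBase.save_model is implemented this way (including the dense-then-sparse order) is an assumption.

```python
from abc import ABC, abstractmethod


class SketchModelBase(ABC):
    """Hypothetical stand-in, not graphstorm.model.GSgnnModelBase."""

    @abstractmethod
    def save_dense_model(self, model_path):
        """Save encoder, GNN, and decoder parameters."""

    @abstractmethod
    def save_sparse_model(self, model_path):
        """Save learnable node embeddings."""

    def save_model(self, model_path):
        # Delegate to the two hooks; the order is an assumption.
        self.save_dense_model(model_path)
        self.save_sparse_model(model_path)


class RecordingModel(SketchModelBase):
    """Toy subclass that records calls instead of writing files."""

    def __init__(self):
        self.calls = []

    def save_dense_model(self, model_path):
        self.calls.append(("dense", model_path))

    def save_sparse_model(self, model_path):
        self.calls.append(("sparse", model_path))


model = RecordingModel()
model.save_model("/tmp/checkpoints")
print(model.calls)
```

With this layout a subclass only implements the two hooks, and a base class that leaves either hook undefined cannot be instantiated.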

abstract create_optimizer()

Create the optimizer that optimizes the model.

Users who want to customize a model should define an optimizer for this model. By using this method, users can define their customized optimization algorithm, the learning rate, and any other hyperparameters.

A model may require multiple optimizers. For example, we should define an optimizer for sparse embeddings and an optimizer for the dense parameters of a GNN model. In this case, a user should use a GSOptimizer to combine these optimizers.

Example:

Case 1: if there is only one optimizer:

def create_optimizer(self):
    # define torch.optim.Optimizer
    return optimizer

Case 2: if there are both dense and sparse optimizers:

def create_optimizer(self):
    dense = [dense_opt] # define torch.optim.Optimizer
    sparse = [sparse_opt] # define dgl sparse Optimizer
    optimizer = GSOptimizer(dense_opts=dense,
                            lm_opts=None,
                            sparse_opts=sparse)
    return optimizer
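
For Case 1, a concrete version might look like the sketch below. It uses a plain torch.nn.Module as a stand-in for a GSgnnModelBase subclass so that it runs without GraphStorm; the class name `TinyModel`, the layer sizes, and the choice of Adam with a learning rate of 0.001 are assumptions made for illustration.

```python
import torch as th
from torch import nn


class TinyModel(nn.Module):
    """Stand-in for a custom model; a real one would subclass GSgnnModelBase."""

    def __init__(self, in_dim=4, out_dim=2, lr=0.001):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.lr = lr

    def forward(self, x):
        return self.linear(x)

    def create_optimizer(self):
        # A single dense optimizer over all trainable parameters.
        return th.optim.Adam(self.parameters(), lr=self.lr)


model = TinyModel()
optimizer = model.create_optimizer()

# One training step on random data to show the optimizer in use.
loss = model(th.randn(8, 4)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Keeping the learning rate on the model instance is only one option; it could equally be read from a configuration object inside create_optimizer.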

prepare_input_encoder(train_data)

Prepare the input layer for training or inference.

The input layer can pre-compute node features during this preparation step if needed, e.g., pre-compute all BERT embeddings.

Default: do nothing

Parameters

train_data: GSgnnData

Graph data

property device

Return the device where the model runs.

This implementation assumes that all model parameters are on the same device.