GSLMNodeEncoderInputLayer

class graphstorm.model.GSLMNodeEncoderInputLayer(g, node_lm_configs, feat_size, embed_size, num_train=0, lm_infer_batch_size=16, activation=None, dropout=0.0, use_node_embeddings=False, use_fp16=True, cached_embed_path=None, wg_cached_embed=False, force_no_embeddings=None)

Bases: GSNodeEncoderInputLayer

The node encoder input layer with language model (LM) support for all nodes in a heterogeneous graph.

This input layer treats node features in the same way as GSNodeEncoderInputLayer. In addition, it adds an LM layer for nodes with textual features and uses the LM to generate embeddings for them. The LM embeddings are then added as another node feature.
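The combination step can be pictured with a minimal pure-Python sketch. It assumes the LM embedding is concatenated to the node's raw features and that a single linear projection then maps the result to embed_size. The function and variable names are hypothetical; this is not GraphStorm's implementation, which works on PyTorch tensors.

```python
from typing import List


def combine_and_project(node_feat: List[float], lm_emb: List[float],
                        weight: List[List[float]], bias: List[float]) -> List[float]:
    """Concatenate raw node features with the node's LM embedding, then apply
    a linear projection y = W x + b that maps the result to embed_size."""
    x = node_feat + lm_emb  # the LM embedding simply becomes extra feature columns
    if len(weight) != len(bias) or any(len(row) != len(x) for row in weight):
        raise ValueError("weight must be embed_size x (feat width + LM width)")
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


# 2 raw feature columns + 3 LM columns = 5 inputs, projected to embed_size = 2.
W = [[1.0, 0.0, 0.0, 0.0, 0.0],   # row 0 copies raw feature 0
     [0.0, 0.0, 1.0, 1.0, 1.0]]   # row 1 sums the LM embedding
b = [0.0, 1.0]
print(combine_and_project([1.0, 2.0], [0.5, 0.0, -1.0], W, b))  # [1.0, 0.5]
```

Because the LM output is treated as extra feature columns, the projection's input width grows by the LM hidden size, while the output width stays at embed_size.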

Parameters

g: DistGraph

The input DGL distributed graph

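node_lm_configs: list of dict

A list of language model configurations. Each configuration describes one LM (for example, its lm_type and model_name) and the node_types it applies to.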

feat_size: dict of int or dict of list of ints

The original feature sizes of each node type, in the format {ntype: size}. If a node type has multiple feature groups, use the format {ntype: [size, size, …]}.

embed_size: int

The output embedding size.

num_train: int

The number of nodes with textual features used for LM model fine-tuning in a mini-batch. Default: 0.

lm_infer_batch_size: int

Batch size used for computing text embeddings with a static (not fine-tuned) LM. Default: 16.

activation: callable

The activation function. Default: None.

dropout: float

The dropout parameter. Default: 0.0.

use_node_embeddings: bool

Whether to use the node embeddings for individual nodes even when node features are available. Default: False.

use_fp16: bool

Whether to use float16 to store the BERT embeddings. Default: True.

cached_embed_path: str

The path where the generated LM embeddings are cached. Default: None.
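Two of the parameters above determine sizes that are easy to check by hand. The following sketch is an illustration under stated assumptions, not GraphStorm code: the helper names are hypothetical, a node type missing from feat_size is treated as width 0, and 768 is used only as an example LM hidden size (that of bert-base-uncased). It computes the per-type input width once LM embeddings are appended, and the storage cost of one cached embedding per node with and without use_fp16.

```python
import struct
from typing import Dict, List, Union

FeatSize = Dict[str, Union[int, List[int]]]


def input_width(feat_size: FeatSize, ntype: str, lm_hidden_size: int = 0) -> int:
    """Total input width for one node type: the sum of all feature-group sizes
    (a bare int counts as one group) plus the LM embedding size."""
    size = feat_size.get(ntype, 0)  # assumption: a missing type has no raw features
    groups = size if isinstance(size, list) else [size]
    return sum(groups) + lm_hidden_size


def cached_embed_bytes(num_nodes: int, lm_hidden_size: int, use_fp16: bool = True) -> int:
    """Bytes for one LM embedding per node: 2 bytes per value as float16
    (struct format 'e') versus 4 bytes as float32 ('f')."""
    return num_nodes * lm_hidden_size * struct.calcsize("e" if use_fp16 else "f")


feat_size = {"paper": 128, "author": [64, 32]}   # hypothetical node types
print(input_width(feat_size, "paper", 768))        # 896
print(input_width(feat_size, "author", 768))       # 864
print(cached_embed_bytes(1_000_000, 768))          # 1536000000
```

Storing in float16 halves the cache footprint compared with float32, which matters when every textual node in a large graph gets a cached embedding.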

Examples:

from graphstorm import get_node_feat_size
from graphstorm.model import GSgnnNodeModel, GSLMNodeEncoderInputLayer
from graphstorm.dataloading import GSgnnData
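
np_data = GSgnnData(...)
model = GSgnnNodeModel(...)
# Sizes of the "feat" node features, per node type.
feat_size = get_node_feat_size(np_data.g, "feat")
# One LM configuration, applied to the node types listed under "node_types".
node_lm_configs = [
    {
        "lm_type": "bert",
        "model_name": "bert-base-uncased",
        "gradient_checkpoint": True,
        "node_types": ['ntype1', 'ntype2']
    }
]
# Number of textual nodes per mini-batch used to fine-tune the LM (num_train).
lm_train_nodes = 10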

encoder = GSLMNodeEncoderInputLayer(
    g=np_data.g,
    node_lm_configs=node_lm_configs,
    feat_size=feat_size,
    embed_size=128,
    num_train=lm_train_nodes
)
model.set_node_input_encoder(encoder)

get_general_dense_parameters()

Get dense layers’ model parameters of this node encoder input layer.

Returns

params: list of Tensors

The dense layers’ model parameters of this node encoder input layer.

get_lm_dense_parameters()

Get the language model related parameters.

Returns

list of Tensors: the language model related parameters.

prepare(g)

Prepare the input layer for training or inference.

If the number of nodes used for LM fine-tuning (num_train) is zero, freeze this layer.

freeze(_)

Generate the LM cache.

The LM cache is used in the following cases:

  1. The LMs do not need fine-tuning, i.e., num_train == 0. In this case, the LM cache is generated only once, before model training.

  2. GNN warm-up when lm_freeze_epochs > 0 (controlled by the trainer). The emb_cache is generated before model training. During the first lm_freeze_epochs epochs, the number of textual nodes used for LM fine-tuning is set to 0, and the LM cache is not refreshed.

  3. If num_train > 0, no emb_cache is used, except in case 2.
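The three cases, together with lm_infer_batch_size, can be sketched as follows. This is a hypothetical illustration, not GraphStorm's code. It assumes epochs are counted from 0, so that "the first lm_freeze_epochs epochs" means epoch < lm_freeze_epochs. It also assumes the cache is built by encoding nodes in chunks of lm_infer_batch_size, with a stand-in encode callable in place of the real LM.

```python
from typing import Callable, Dict, List, Sequence


def uses_lm_cache(num_train: int, lm_freeze_epochs: int, epoch: int) -> bool:
    """Whether the LM cache serves a given (0-based) training epoch."""
    if num_train == 0:
        return True                    # case 1: the LM is never fine-tuned
    return epoch < lm_freeze_epochs    # case 2 (GNN warm-up); otherwise case 3


def build_lm_cache(node_ids: Sequence[int],
                   encode: Callable[[List[int]], List[List[float]]],
                   lm_infer_batch_size: int = 16) -> Dict[int, List[float]]:
    """Encode every node once, lm_infer_batch_size nodes at a time."""
    if lm_infer_batch_size <= 0:
        raise ValueError("lm_infer_batch_size must be positive")
    cache: Dict[int, List[float]] = {}
    for start in range(0, len(node_ids), lm_infer_batch_size):
        batch = list(node_ids[start:start + lm_infer_batch_size])
        for nid, emb in zip(batch, encode(batch)):
            cache[nid] = emb
    return cache


calls: List[int] = []


def fake_encode(batch: List[int]) -> List[List[float]]:
    """Hypothetical stand-in for running the LM on each node's text."""
    calls.append(len(batch))
    return [[float(n)] for n in batch]


cache = build_lm_cache(list(range(40)), fake_encode, lm_infer_batch_size=16)
print(calls)  # [16, 16, 8]
print(uses_lm_cache(0, 0, 5), uses_lm_cache(10, 2, 1), uses_lm_cache(10, 2, 2))  # True True False
```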

unfreeze()

Disable LM caching.

If num_train > 0 and the LM cache is not in use, clear the existing LM cache.
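The freeze/unfreeze lifecycle can be mimicked with a small toy class. It is hypothetical (ToyLMCacheLayer and encode_all are not GraphStorm names). It encodes one reading of the rules above: freezing builds the cache once, and unfreezing drops it only when the layer fine-tunes its LM (num_train > 0), because only then would the cached embeddings go stale.

```python
from typing import Callable, Dict, List, Optional


class ToyLMCacheLayer:
    """Hypothetical stand-in that mimics the freeze()/unfreeze() cache rules."""

    def __init__(self, num_train: int,
                 encode_all: Callable[[], Dict[int, List[float]]]):
        self.num_train = num_train      # textual nodes fine-tuned per mini-batch
        self.encode_all = encode_all    # runs the LM over every textual node
        self.use_cache = False
        self.cache: Optional[Dict[int, List[float]]] = None

    def freeze(self) -> None:
        # Build the cache once; later forward passes read from it.
        self.use_cache = True
        if self.cache is None:
            self.cache = self.encode_all()

    def unfreeze(self) -> None:
        # Only a layer that fine-tunes its LM stops caching and drops the
        # stale cache; with num_train == 0 the cached embeddings stay valid.
        if self.num_train > 0:
            self.use_cache = False
            self.cache = None


tuned = ToyLMCacheLayer(num_train=10, encode_all=lambda: {0: [0.1, 0.2]})
tuned.freeze()     # e.g. before the GNN warm-up epochs
tuned.unfreeze()   # warm-up over: LM weights will change, cache is discarded
static = ToyLMCacheLayer(num_train=0, encode_all=lambda: {0: [0.1, 0.2]})
static.freeze()
static.unfreeze()  # no-op: the frozen LM's embeddings remain usable
print(tuned.cache, static.cache)  # None {0: [0.1, 0.2]}
```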

require_cache_embed()

Whether to cache the embeddings for inference.

Returns

bool: True, indicating that the embeddings should be cached for inference.

forward(input_feats, input_nodes)

Input layer forward computation.

The forward function computes the LM embeddings and combines them with the input node features for further projection.

Parameters

input_feats: dict of Tensor

The input features in the format of {ntype: feats}.

input_nodes: dict of Tensor

The input node indexes in the format of {ntype: indexes}.

Returns

a dict of Tensor: The projected node embeddings in the format of {ntype: emb}.
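The dictionary-in, dictionary-out flow of forward can be sketched in pure Python. The sketch rests on two assumptions, both of them hypothetical and not confirmed by this page. First, up to num_train nodes per mini-batch are picked at random for the gradient-carrying LM pass. Second, each node's LM embedding is appended to its feature row before projection; the projection itself is omitted. Node types without raw features start from empty rows, so their input is the LM embedding alone.

```python
import random
from typing import Dict, List, Tuple


def split_for_finetune(node_ids: List[int], num_train: int,
                       rng: random.Random) -> Tuple[List[int], List[int]]:
    """Pick up to num_train nodes whose LM pass would receive gradients;
    the remaining nodes would be encoded without gradients (assumed)."""
    chosen = set(rng.sample(node_ids, min(num_train, len(node_ids))))
    return ([n for n in node_ids if n in chosen],
            [n for n in node_ids if n not in chosen])


def toy_forward(input_feats: Dict[str, List[List[float]]],
                input_nodes: Dict[str, List[int]],
                lm_embs: Dict[str, Dict[int, List[float]]]) -> Dict[str, List[List[float]]]:
    """Per node type, append each input node's LM embedding to its feature row.
    The projection to embed_size that follows is omitted here."""
    out: Dict[str, List[List[float]]] = {}
    for ntype, nids in input_nodes.items():
        # Text-only node types may have no raw features: start from empty rows.
        rows = input_feats.get(ntype, [[] for _ in nids])
        table = lm_embs.get(ntype)
        out[ntype] = rows if table is None else [r + table[n] for r, n in zip(rows, nids)]
    return out


trainable, no_grad = split_for_finetune([5, 6, 7, 8], num_train=2, rng=random.Random(0))
print(len(trainable), len(no_grad))  # 2 2
print(toy_forward({"paper": [[1.0], [2.0]]},
                  {"paper": [3, 4], "author": [0]},
                  {"paper": {3: [0.5], 4: [0.25]}, "author": {0: [9.0]}}))
# {'paper': [[1.0, 0.5], [2.0, 0.25]], 'author': [[9.0]]}
```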