GSLMNodeEncoderInputLayer
- class graphstorm.model.GSLMNodeEncoderInputLayer(g, node_lm_configs, feat_size, embed_size, num_train=0, lm_infer_batch_size=16, activation=None, dropout=0.0, use_node_embeddings=False, use_fp16=True, cached_embed_path=None, wg_cached_embed=False, force_no_embeddings=None)
Bases: GSNodeEncoderInputLayer
The node encoder input layer with language model (LM) support for all nodes in a heterogeneous graph.
This input layer treats node features in the same way as GSNodeEncoderInputLayer. In addition, it adds an LM layer on nodes with textual features and uses the LM to generate LM embeddings, which are then added as another type of node feature.
Parameters
- g: DistGraph
The input DGL distributed graph
- node_lm_configs: LM config
A list of language model configurations.
- feat_size: dict of int or dict of list of ints
The original feature sizes of each node type in the format of {ntype: size}. If a node type has multiple feature groups, the format is {ntype: [size, size, …]}.
- embed_size: int
The output embedding size.
- num_train: int
The number of nodes with textual features used for LM model fine-tuning in a mini-batch. Default: 0.
- lm_infer_batch_size: int
Batch size used for computing text embeddings for static LM model. Default: 16.
- activation: callable
The activation function. Default: None.
- dropout: float
The dropout parameter. Default: 0.0.
- use_node_embeddings: bool
Whether to use the node embeddings for individual nodes even when node features are available. Default: False.
- use_fp16: bool
Whether to use float16 to store the BERT embeddings. Default: True.
- cached_embed_path: str
The path where the generated LM embeddings are cached. Default: None.
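Since the LM embeddings are added as another type of node feature, it can help to picture what feat_size looks like once they are included. The sketch below is illustrative only: it normalizes both documented feat_size formats into the list form and appends the LM embedding width as one extra feature group on text node types. The helper names and the lm_hidden_size parameter are hypothetical (768 is the hidden size of bert-base-uncased), and this is not how GraphStorm necessarily computes its internal dimensions.

```python
from typing import Dict, List, Union

FeatSize = Dict[str, Union[int, List[int]]]


def normalize_feat_size(feat_size: FeatSize) -> Dict[str, List[int]]:
    """Convert {ntype: size} / {ntype: [size, ...]} into {ntype: [size, ...]}."""
    return {
        ntype: list(size) if isinstance(size, list) else [size]
        for ntype, size in feat_size.items()
    }


def with_lm_feature(feat_size: FeatSize, lm_ntypes: List[str],
                    lm_hidden_size: int) -> Dict[str, List[int]]:
    """Treat the LM embedding as one more feature group on text node types.

    lm_hidden_size is a hypothetical parameter of this sketch.
    """
    groups = normalize_feat_size(feat_size)  # copies, input is not mutated
    for ntype in lm_ntypes:
        groups.setdefault(ntype, []).append(lm_hidden_size)
    return groups


feat_size = {"ntype1": 16, "ntype2": [8, 4], "ntype3": 32}
print(with_lm_feature(feat_size, ["ntype1", "ntype2"], 768))
# {'ntype1': [16, 768], 'ntype2': [8, 4, 768], 'ntype3': [32]}
```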
Examples:

```python
from graphstorm import get_node_feat_size
from graphstorm.model import GSgnnNodeModel, GSLMNodeEncoderInputLayer
from graphstorm.dataloading import GSgnnData

np_data = GSgnnData(...)
model = GSgnnNodeModel(...)

feat_size = get_node_feat_size(np_data.g, "feat")
node_lm_configs = [
    {
        "lm_type": "bert",
        "model_name": "bert-base-uncased",
        "gradient_checkpoint": True,
        "node_types": ['ntype1', 'ntype2']
    }
]
lm_train_nodes = 10

encoder = GSLMNodeEncoderInputLayer(
    g=np_data.g,
    node_lm_configs=node_lm_configs,
    feat_size=feat_size,
    embed_size=128,
    num_train=lm_train_nodes
)
model.set_node_input_encoder(encoder)
```
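Each entry of node_lm_configs lists the node_types its LM applies to. A minimal sketch of resolving which LM encodes which node type is shown below, assuming (as this sketch does, not the documentation) that a node type should be covered by at most one config. map_ntypes_to_lm is a hypothetical helper, not part of GraphStorm.

```python
from typing import Any, Dict, List


def map_ntypes_to_lm(node_lm_configs: List[Dict[str, Any]]) -> Dict[str, str]:
    """Hypothetical helper: map each node type to the LM model that encodes it.

    Raises ValueError if a node type appears in more than one config, under
    this sketch's assumption that each node type gets a single LM.
    """
    ntype_to_model: Dict[str, str] = {}
    for cfg in node_lm_configs:
        for ntype in cfg["node_types"]:
            if ntype in ntype_to_model:
                raise ValueError(f"node type {ntype!r} assigned to multiple LMs")
            ntype_to_model[ntype] = cfg["model_name"]
    return ntype_to_model


node_lm_configs = [
    {
        "lm_type": "bert",
        "model_name": "bert-base-uncased",
        "gradient_checkpoint": True,
        "node_types": ["ntype1", "ntype2"],
    }
]
print(map_ntypes_to_lm(node_lm_configs))
# {'ntype1': 'bert-base-uncased', 'ntype2': 'bert-base-uncased'}
```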
- get_general_dense_parameters()
Get dense layers’ model parameters of this node encoder input layer.
Returns
- params: list of Tensors
The dense layers’ model parameters of this node encoder input layer.
- get_lm_dense_parameters()
Get the language model related parameters.
Returns
list of Tensors: the language model related parameters.
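The two getters keep dense-layer parameters and LM parameters apart. One plausible use is giving the LM its own, often smaller, learning rate. The sketch below builds the list-of-dicts "param groups" format that PyTorch optimizers accept; the name-prefix convention, the learning rates, and the string stand-ins for tensors are all assumptions made for illustration.

```python
from typing import Dict, List


def split_param_groups(named_params: Dict[str, object], lm_prefix: str,
                       lr: float, lm_lr: float) -> List[dict]:
    """Split parameters into a dense group and an LM group.

    The lm_prefix naming convention is hypothetical; the output follows the
    [{"params": [...], "lr": ...}, ...] shape used by torch.optim optimizers.
    """
    dense = [p for name, p in named_params.items() if not name.startswith(lm_prefix)]
    lm = [p for name, p in named_params.items() if name.startswith(lm_prefix)]
    return [{"params": dense, "lr": lr}, {"params": lm, "lr": lm_lr}]


# Strings stand in for tensors so the sketch runs without PyTorch.
groups = split_param_groups(
    {"proj.ntype1.weight": "W1", "lm.bert.weight": "W2"}, "lm.", 1e-3, 1e-5
)
print([g["params"] for g in groups])
# [['W1'], ['W2']]
```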
- prepare(g)
Prepare the input layer for training or inference.
If the number of nodes for LM fine-tuning is zero, freeze this layer.
- freeze(_)
Generate the LM cache.
The LM cache is used in the following cases:
1. There is no need to fine-tune the LMs, i.e., num_train == 0. In this case, the LM cache is generated only once, before model training.
2. GNN warm-up when lm_freeze_epochs > 0 (controlled by the trainer). The emb_cache is generated before model training. In the first lm_freeze_epochs epochs, the number of nodes with text features used for LM fine-tuning is set to 0, and the LM cache is not refreshed.
3. If num_train > 0, no emb_cache is used, except in Case 2.
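The three cases can be condensed into a single predicate over the training schedule. This is a sketch that assumes 0-based epoch numbering; lm_cache_in_use is a hypothetical function, and the real decision is made by the layer together with the trainer.

```python
def lm_cache_in_use(num_train: int, lm_freeze_epochs: int, epoch: int) -> bool:
    """Return True if the LM cache would be used in the given 0-based epoch.

    Case 1: num_train == 0, the LM is never fine-tuned, so the cache is
            always used.
    Case 2: epoch < lm_freeze_epochs, GNN warm-up; LM fine-tuning is paused
            and the pre-generated cache is used without refreshing.
    Case 3: otherwise, with num_train > 0, no cache is used.
    """
    if num_train == 0:
        return True
    return epoch < lm_freeze_epochs


print([lm_cache_in_use(10, 2, e) for e in range(4)])
# [True, True, False, False]
```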
- unfreeze()
Disable LM caching.
If num_train > 0 and the LM cache is not in use, clear the existing LM cache.
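The prepare()/freeze()/unfreeze() behavior described above can be modeled with a toy class. This is a simplified sketch, not GraphStorm's implementation: ToyLMCache and embed_fn (a stand-in for running the LM over node text) are hypothetical, and the cache is a plain dict keyed by node ID.

```python
from typing import Callable, Dict, List


class ToyLMCache:
    """Toy model of the prepare()/freeze()/unfreeze() life cycle."""

    def __init__(self, num_train: int, embed_fn: Callable[[str], List[float]]):
        self.num_train = num_train
        self.embed_fn = embed_fn  # hypothetical stand-in for the LM
        self.cache: Dict[int, List[float]] = {}
        self.use_cache = False

    def prepare(self, texts: Dict[int, str]) -> None:
        # No LM fine-tuning requested: freeze once and reuse the cache.
        if self.num_train == 0:
            self.freeze(texts)

    def freeze(self, texts: Dict[int, str]) -> None:
        # Generate the LM cache for every node and switch to using it.
        self.cache = {nid: self.embed_fn(t) for nid, t in texts.items()}
        self.use_cache = True

    def unfreeze(self) -> None:
        # Disable caching; drop stale embeddings if the LM will be fine-tuned.
        self.use_cache = False
        if self.num_train > 0:
            self.cache.clear()


texts = {0: "ab", 1: "cde"}
layer = ToyLMCache(num_train=0, embed_fn=lambda t: [float(len(t))])
layer.prepare(texts)
print(layer.use_cache, layer.cache)
# True {0: [2.0], 1: [3.0]}
```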
- require_cache_embed()
Whether to cache the embeddings for inference.
Returns
bool: True if the embeddings should be cached for inference.
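Caching embeddings for inference means running a frozen LM over all node text once. A minimal sketch of doing that in fixed-size chunks is below; the default chunk size of 16 mirrors the documented lm_infer_batch_size default, but tying the two together here is an assumption, and embed_batch is a hypothetical stand-in for a static LM forward pass.

```python
from typing import Callable, List, Sequence


def embed_in_batches(texts: Sequence[str],
                     embed_batch: Callable[[Sequence[str]], List[List[float]]],
                     batch_size: int = 16) -> List[List[float]]:
    """Run a frozen encoder over texts in chunks of batch_size, in order."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    out: List[List[float]] = []
    for start in range(0, len(texts), batch_size):
        out.extend(embed_batch(texts[start:start + batch_size]))
    return out


def fake_lm(batch: Sequence[str]) -> List[List[float]]:
    # Hypothetical encoder: one-dimensional "embedding" = text length.
    return [[float(len(t))] for t in batch]


print(embed_in_batches(["a", "bb", "ccc"], fake_lm, batch_size=2))
# [[1.0], [2.0], [3.0]]
```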
- forward(input_feats, input_nodes)
Input layer forward computation.
The forward function computes the LM embeddings and combines them with the input node features for further projection.
Parameters
- input_feats: dict of Tensor
The input features in the format of {ntype: feats}.
- input_nodes: dict of Tensor
The input node indexes in the format of {ntype: indexes}.
Returns
a dict of Tensor: The projected node embeddings in the format of {ntype: emb}.
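To make the "combine, then project" step concrete, the sketch below concatenates each node's input features with its LM embedding and applies a per-node-type dense projection, all in plain Python. Concatenating before a single projection is an assumption of this sketch, not necessarily how GraphStorm combines the two, and toy_forward and project are hypothetical names.

```python
from typing import Dict, List

Vector = List[float]
Matrix = List[List[float]]  # shape: [out_dim][in_dim]


def project(x: Vector, weight: Matrix) -> Vector:
    """Dense projection y = W x (no bias) in plain Python."""
    if any(len(row) != len(x) for row in weight):
        raise ValueError("weight width does not match input size")
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def toy_forward(input_feats: Dict[str, List[Vector]],
                lm_embs: Dict[str, List[Vector]],
                weights: Dict[str, Matrix]) -> Dict[str, List[Vector]]:
    """Concatenate node features with LM embeddings (if any), then project."""
    out = {}
    for ntype, feats in input_feats.items():
        embs = lm_embs.get(ntype)
        rows = []
        for i, f in enumerate(feats):
            # Node types without text keep only their original features.
            x = f + embs[i] if embs is not None else f
            rows.append(project(x, weights[ntype]))
        out[ntype] = rows
    return out


print(toy_forward({"paper": [[1.0, 2.0]]},
                  {"paper": [[3.0]]},
                  {"paper": [[1.0, 1.0, 1.0], [0.0, 0.0, 2.0]]}))
# {'paper': [[6.0, 6.0]]}
```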