graphstorm.model.GSLMNodeEncoderInputLayer
- class graphstorm.model.GSLMNodeEncoderInputLayer(g, node_lm_configs, feat_size, embed_size, num_train=0, lm_infer_batch_size=16, activation=None, dropout=0.0, use_node_embeddings=False, use_fp16=True, cached_embed_path=None, wg_cached_embed=False, force_no_embeddings=None)
The input encoder layer with language models for all nodes in a heterogeneous graph.
The input layer adds a language model (LM) layer to nodes with textual node features and uses the LM to generate embeddings for them. These LM embeddings are then treated as node features.
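Treating LM embeddings as node features affects the input dimension of each node type that has text. The sketch below assumes the LM output is concatenated with any existing features of that node type, so the per-type input size grows by the LM's hidden size (768 for `bert-base-uncased`). `adjust_feat_size` and the `lm_hidden_size` mapping are hypothetical names for illustration, not GraphStorm API.

```python
def adjust_feat_size(feat_size, node_lm_configs, lm_hidden_size):
    """Hypothetical helper: per-node-type input sizes after adding LM embeddings.

    feat_size: dict mapping node type -> original feature size (0 if none).
    node_lm_configs: list of dicts, each with "model_name" and "node_types".
    lm_hidden_size: dict mapping model_name -> LM output dimension (assumed input).
    """
    adjusted = dict(feat_size)  # do not mutate the caller's dict
    for config in node_lm_configs:
        hidden = lm_hidden_size[config["model_name"]]
        for ntype in config["node_types"]:
            # Assumption: the LM embedding is appended to existing features.
            adjusted[ntype] = adjusted.get(ntype, 0) + hidden
    return adjusted


feat_size = {"a": 0, "b": 16}
configs = [{"lm_type": "bert", "model_name": "bert-base-uncased",
            "gradient_checkpoint": True, "node_types": ["a"]}]
print(adjust_feat_size(feat_size, configs, {"bert-base-uncased": 768}))
# {'a': 768, 'b': 16}
```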
The input layer adds learnable embeddings to nodes that have no features. For nodes with features, it adds a linear layer that projects the node features to a specified dimension. A user can also add learnable embeddings to nodes that have features; in this case, the input layer combines the node features with the learnable embeddings and projects the result to the specified dimension.
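The three cases described above (no features, features only, features plus learnable embeddings) can be illustrated with a toy, pure-Python layer. This is a minimal sketch, not GraphStorm's implementation: `ToyInputLayer`, `matvec`, and `rand_matrix` are hypothetical, activation and dropout are omitted, and "combining" is assumed to mean concatenating the raw features with the embedding before a single projection; the library may combine them differently. For LM node types, `feats` would stand for the features that already include the LM embedding.

```python
import random


def matvec(vec, weight):
    """Multiply a row vector of length d_in by a d_in x d_out matrix."""
    return [sum(v * weight[i][j] for i, v in enumerate(vec))
            for j in range(len(weight[0]))]


def rand_matrix(rng, rows, cols):
    """Small random matrix standing in for trainable weights."""
    return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]


class ToyInputLayer:
    """Illustrative only; not GraphStorm's GSLMNodeEncoderInputLayer."""

    def __init__(self, feat_size, embed_size, num_nodes,
                 use_node_embeddings=False, seed=0):
        rng = random.Random(seed)
        self.use_node_embeddings = use_node_embeddings
        self.node_embs = {}  # ntype -> per-node learnable embedding table
        self.proj = {}       # ntype -> projection matrix (in_dim x embed_size)
        for ntype, size in feat_size.items():
            if size == 0 or use_node_embeddings:
                self.node_embs[ntype] = rand_matrix(rng, num_nodes[ntype], embed_size)
            if size > 0:
                # Assumed input is [features, embedding] when both are used.
                in_dim = size + (embed_size if use_node_embeddings else 0)
                self.proj[ntype] = rand_matrix(rng, in_dim, embed_size)

    def encode(self, ntype, node_id, feats=None):
        if ntype not in self.proj:
            # Featureless node type: the learnable embedding is the input.
            return self.node_embs[ntype][node_id]
        vec = list(feats)
        if self.use_node_embeddings:
            vec += self.node_embs[ntype][node_id]  # concatenate, then project
        return matvec(vec, self.proj[ntype])


layer = ToyInputLayer({"a": 0, "b": 3}, embed_size=4,
                      num_nodes={"a": 5, "b": 5}, use_node_embeddings=True)
print(len(layer.encode("a", 0)), len(layer.encode("b", 2, [1.0, 0.5, -1.0])))
# 4 4
```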
Parameters
- g: DistGraph
The distributed graph
- node_lm_configs: list of dict
A list of language model configurations.
- feat_size: dict of int
The original feature size of each node type.
- embed_size: int
The embedding size, i.e., the dimension that the input features are projected to.
- num_train: int
Number of trainable texts. Default: 0
- lm_infer_batch_size: int
Batch size used for computing text embeddings with a static (not fine-tuned) LM. Default: 16
- activation: func
The activation function. Default: None
- dropout: float
The dropout rate. Default: 0.0
- use_node_embeddings: bool
Whether to use learnable node embeddings for individual nodes even when node features are available. Default: False
- use_fp16: bool
Whether to store the LM (e.g., BERT) embeddings in float16. Default: True
- cached_embed_path: str
The path where the LM embeddings are cached. Default: None
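Two of the parameters above, `lm_infer_batch_size` and `use_fp16`, govern how embeddings from a static LM are computed and stored. The pure-Python sketch below assumes texts are embedded in consecutive chunks of `lm_infer_batch_size` and each vector is packed as IEEE half precision (2 bytes per value) when `use_fp16` is true, versus float32 (4 bytes) otherwise. `embed_all`, `batched`, and `toy_lm` are hypothetical; `toy_lm` stands in for a real LM.

```python
import struct


def batched(items, batch_size):
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def embed_all(texts, embed_fn, lm_infer_batch_size=16, use_fp16=True):
    """Embed texts batch by batch; return (num_batches, packed_bytes).

    embed_fn is a stand-in for a frozen LM: it maps a list of texts to a
    list of equal-length float vectors.
    """
    code = "e" if use_fp16 else "f"  # 'e': 2-byte half, 'f': 4-byte float32
    packed, num_batches = [], 0
    for batch in batched(texts, lm_infer_batch_size):
        num_batches += 1
        for vec in embed_fn(batch):
            packed.append(struct.pack(f"<{len(vec)}{code}", *vec))
    return num_batches, b"".join(packed)


def toy_lm(batch):
    # Hypothetical 2-dimensional "embedding" per text.
    return [[len(t) / 10, 0.1] for t in batch]


texts = [f"node {i}" for i in range(40)]
n16, half = embed_all(texts, toy_lm)
n32, full = embed_all(texts, toy_lm, use_fp16=False)
print(n16, len(half), len(full))  # 3 160 320
```

With 40 texts and a batch size of 16 the LM is called three times (16, 16, then 8 texts), and float16 storage halves the memory at the cost of precision (0.1 is stored as roughly 0.09998).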
Examples:
```python
from graphstorm import get_node_feat_size
from graphstorm.model import GSgnnNodeModel, GSLMNodeEncoderInputLayer
from graphstorm.dataloading import GSgnnNodeTrainData

np_data = GSgnnNodeTrainData(...)
model = GSgnnNodeModel(...)

# Original feature size of each node type.
feat_size = get_node_feat_size(np_data.g, 'feat')
# Use a BERT model on the text of node type 'a'.
node_lm_configs = [{"lm_type": "bert",
                    "model_name": "bert-base-uncased",
                    "gradient_checkpoint": True,
                    "node_types": ['a']}]
lm_train_nodes = 10

encoder = GSLMNodeEncoderInputLayer(
    g=np_data.g, node_lm_configs=node_lm_configs,
    feat_size=feat_size, embed_size=128,
    num_train=lm_train_nodes)
model.set_node_input_encoder(encoder)
```
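The example passes `lm_train_nodes=10` as `num_train`, the number of trainable texts. One plausible reading, stated here as an assumption rather than documented behavior, is that within each mini-batch at most `num_train` text nodes get a gradient-tracked LM pass while the rest are embedded without gradients, and `num_train=0` leaves the LM un-fine-tuned. The hypothetical `split_trainable` below sketches only that selection step.

```python
import random


def split_trainable(node_ids, num_train, rng=None):
    """Split a mini-batch's text nodes into (trainable, frozen) lists.

    Assumption: at most num_train nodes get a gradient-tracked LM pass;
    num_train == 0 means the LM is not fine-tuned for this batch.
    """
    if num_train < 0:
        raise ValueError("this sketch expects num_train >= 0")
    rng = rng or random.Random()
    k = min(num_train, len(node_ids))
    chosen = set(rng.sample(range(len(node_ids)), k))
    trainable = [n for i, n in enumerate(node_ids) if i in chosen]
    frozen = [n for i, n in enumerate(node_ids) if i not in chosen]
    return trainable, frozen


batch = list(range(100, 132))  # 32 text nodes in one mini-batch
train, frozen = split_trainable(batch, num_train=10, rng=random.Random(0))
print(len(train), len(frozen))  # 10 22
```

Capping the trainable subset bounds the memory and compute of LM backpropagation per mini-batch regardless of how many text nodes the sampler returns.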