graphstorm.model.GSPureLMNodeInputLayer
- class graphstorm.model.GSPureLMNodeInputLayer(g, node_lm_configs, num_train=0, lm_infer_batch_size=16, use_fp16=True)
An input embedding layer that uses only language models to encode all nodes in a heterogeneous graph.
This input layer consists solely of the language model layer, so every node type is expected to have a text feature. Its output dimension equals the output dimension of the language model.
Use GSLMNodeEncoderInputLayer instead if nodes have extra features or a different output dimension is required.
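The choice between the two layers can be sketched as a small decision helper. This is an illustration only: `pick_input_layer` and its arguments are hypothetical names, not part of GraphStorm, and the function returns the class name as a string so that it runs without GraphStorm installed.

```python
from typing import Optional


def pick_input_layer(has_extra_node_feats: bool,
                     lm_out_dim: int,
                     wanted_out_dim: Optional[int] = None) -> str:
    """Return the name of the input layer class suggested by the rule above.

    Hypothetical helper for illustration; not a GraphStorm API.
    """
    # A pure-LM layer emits exactly the language model's output dimension,
    # so any other requested dimension rules it out.
    needs_other_dim = wanted_out_dim is not None and wanted_out_dim != lm_out_dim
    if has_extra_node_feats or needs_other_dim:
        return "GSLMNodeEncoderInputLayer"
    return "GSPureLMNodeInputLayer"
```

For instance, with a 768-dimensional language model such as bert-base-uncased, `pick_input_layer(False, 768)` selects the pure-LM layer, while `pick_input_layer(False, 768, wanted_out_dim=128)` selects GSLMNodeEncoderInputLayer.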
Parameters
- g: DistGraph
The distributed graph.
- node_lm_configs: list of dict
A list of language model configurations.
- num_train: int
Number of trainable texts. Default: 0
- lm_infer_batch_size: int
Batch size used when computing text embeddings with a static language model. Default: 16
- use_fp16: bool
Use float16 to store BERT embeddings. Default: True
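Because every node type is expected to be covered by a language model, it can help to check a node_lm_configs list before building the layer. The following is a minimal sketch, not GraphStorm code: `check_lm_coverage` and `ASSUMED_KEYS` are hypothetical names, the set of required keys is assumed from the example configuration on this page, and rejecting a node type that appears in more than one configuration is a choice made by this sketch.

```python
from typing import Dict, List, Sequence

# Keys taken from the example configuration on this page; treating them
# as required is an assumption of this sketch.
ASSUMED_KEYS = ("lm_type", "model_name", "node_types")


def check_lm_coverage(node_lm_configs: List[Dict],
                      ntypes: Sequence[str]) -> Dict[str, str]:
    """Map each node type to its model name, or raise ValueError.

    Hypothetical helper for illustration; not a GraphStorm API.
    """
    ntype_to_model: Dict[str, str] = {}
    for i, cfg in enumerate(node_lm_configs):
        missing = [k for k in ASSUMED_KEYS if k not in cfg]
        if missing:
            raise ValueError(f"config {i} is missing keys: {missing}")
        for ntype in cfg["node_types"]:
            # This sketch allows at most one language model per node type.
            if ntype in ntype_to_model:
                raise ValueError(
                    f"node type {ntype!r} appears in more than one config")
            ntype_to_model[ntype] = cfg["model_name"]
    # Every node type in the graph needs a language model.
    uncovered = [nt for nt in ntypes if nt not in ntype_to_model]
    if uncovered:
        raise ValueError(f"node types without a language model: {uncovered}")
    return ntype_to_model
```

In practice the list of node types would come from the graph itself, for example `g.ntypes` on a DGL graph.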
Examples:
    from graphstorm.model import GSgnnNodeModel, GSPureLMNodeInputLayer
    from graphstorm.dataloading import GSgnnNodeTrainData

    # Language model settings for node type 'a'.
    node_lm_configs = [
        {
            "lm_type": "bert",
            "model_name": "bert-base-uncased",
            "gradient_checkpoint": True,
            "node_types": ['a']
        }
    ]

    np_data = GSgnnNodeTrainData(...)
    model = GSgnnNodeModel(...)
    lm_train_nodes = 10

    # Build the LM-only input layer and attach it as the node input encoder.
    encoder = GSPureLMNodeInputLayer(g=np_data.g,
                                     node_lm_configs=node_lm_configs,
                                     num_train=lm_train_nodes)
    model.set_node_input_encoder(encoder)