graphstorm.model.GSPureLMNodeInputLayer

class graphstorm.model.GSPureLMNodeInputLayer(g, node_lm_configs, num_train=0, lm_infer_batch_size=16, use_fp16=True)

An input embedding layer that uses only language models to encode all nodes in a heterogeneous graph.

This input layer contains only the language model layer, so each node type should have a text feature. The output dimension is the same as the output dimension of the language model.

Use GSLMNodeEncoderInputLayer instead if the nodes have extra (non-text) features or if a different output dimension is required.

Parameters

g: DistGraph

The distributed graph.

node_lm_configs: list of dict

A list of language model configurations, one dictionary per language model.

num_train: int

Number of trainable texts. Default: 0

lm_infer_batch_size: int

Batch size used when computing text embeddings with a static language model. Default: 16

use_fp16: bool

Use float16 to store BERT embeddings. Default: True
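To give a feel for what lm_infer_batch_size and use_fp16 control, the following is a minimal, self-contained sketch and not GraphStorm's actual implementation. It assumes a hidden size of 768 (the hidden size of bert-base-uncased); iter_batches, fake_encode, and emb_cache are hypothetical names used only for illustration, and fake_encode stands in for a real language model forward pass.

```python
import numpy as np

def iter_batches(num_items, batch_size):
    """Yield (start, end) index pairs that cover range(num_items) in order."""
    for start in range(0, num_items, batch_size):
        yield start, min(start + batch_size, num_items)

def fake_encode(batch_len, hidden_size):
    # Hypothetical stand-in for a language model forward pass:
    # returns one float32 vector per text in the batch.
    return np.ones((batch_len, hidden_size), dtype=np.float32)

num_texts = 50            # illustrative number of node texts
hidden_size = 768         # assumption: hidden size of bert-base-uncased
lm_infer_batch_size = 16  # same value as the documented default
use_fp16 = True           # same value as the documented default

# Choose the storage dtype for the cached embeddings.
store_dtype = np.float16 if use_fp16 else np.float32
emb_cache = np.zeros((num_texts, hidden_size), dtype=store_dtype)

# Encode the texts batch by batch; values are cast to store_dtype on assignment.
batches = list(iter_batches(num_texts, lm_infer_batch_size))
for start, end in batches:
    emb_cache[start:end] = fake_encode(end - start, hidden_size)

# Size the same cache would have if stored as float32 (4 bytes per value).
fp32_bytes = num_texts * hidden_size * 4
print(len(batches), emb_cache.nbytes, fp32_bytes)
```

In this sketch, a larger batch size means fewer forward passes at the cost of more memory per pass, and float16 storage halves the size of the cached embeddings relative to float32 at the cost of reduced precision.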

Examples:

from graphstorm.model import GSgnnNodeModel, GSPureLMNodeInputLayer
from graphstorm.dataloading import GSgnnNodeTrainData

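# One configuration per language model; "node_types" lists the node
# types whose text features this model encodes.
node_lm_configs = [
    {
        "lm_type": "bert",
        "model_name": "bert-base-uncased",
        "gradient_checkpoint": True,
        "node_types": ["a"]
    }
]
# The constructor arguments are elided here; supply your own.
np_data = GSgnnNodeTrainData(...)
model = GSgnnNodeModel(...)
# Number of trainable texts, passed as num_train.
lm_train_nodes = 10
encoder = GSPureLMNodeInputLayer(g=np_data.g, node_lm_configs=node_lm_configs,
                                 num_train=lm_train_nodes)
model.set_node_input_encoder(encoder)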