graphstorm.model.GSLMNodeEncoderInputLayer

class graphstorm.model.GSLMNodeEncoderInputLayer(g, node_lm_configs, feat_size, embed_size, num_train=0, lm_infer_batch_size=16, activation=None, dropout=0.0, use_node_embeddings=False, use_fp16=True, cached_embed_path=None, wg_cached_embed=False, force_no_embeddings=None)

An input encoder layer with language model (LM) support for all nodes in a heterogeneous graph.

The input layer adds a language model layer on nodes with textual node features and uses the LM to generate embeddings for them. The LM embeddings are then treated as node features.

The input layer adds learnable embeddings on nodes that do not have features. On nodes with features, it adds a linear layer that projects the node features to a specified dimension. A user can also add learnable embeddings on nodes that have features. In this case, the input layer combines the node features with the learnable embeddings and projects them to the specified dimension.
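The per-node-type rules above can be sketched as a small decision function. This is only an illustrative sketch, not GraphStorm's implementation: the function name `plan_input_modules` and the returned dictionary keys are hypothetical, a feature size of 0 is assumed to mean "no features", and features and learnable embeddings are assumed to be combined by concatenation before the projection.

```python
def plan_input_modules(feat_size, embed_size, use_node_embeddings=False):
    """Return, per node type, which input modules the rules above imply.

    feat_size maps node type -> feature dimension
    (0 means "no features", an assumption of this sketch).
    """
    plan = {}
    for ntype, in_dim in feat_size.items():
        has_feat = in_dim > 0
        # Featureless nodes always get a learnable embedding; nodes with
        # features get one only when use_node_embeddings is set.
        learnable = (not has_feat) or use_node_embeddings
        if has_feat and learnable:
            # Assumed combination rule: concatenate, then project.
            proj_in = in_dim + embed_size
        elif has_feat:
            proj_in = in_dim
        else:
            proj_in = None  # the embedding already has embed_size columns
        plan[ntype] = {
            "learnable_embedding": learnable,
            "projection": None if proj_in is None else (proj_in, embed_size),
        }
    return plan


plan = plan_input_modules({"paper": 768, "author": 0}, embed_size=128,
                          use_node_embeddings=True)
for ntype, modules in plan.items():
    print(ntype, modules)
```

With `use_node_embeddings=False`, the hypothetical "paper" type would keep only its linear projection, while "author" would still receive a learnable embedding because it has no features.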

Parameters

g: DistGraph

The distributed graph

node_lm_configs: list

A list of language model configurations.

feat_size: dict of int

The original feature size of each node type

embed_size: int

The embedding size

num_train: int

Number of trainable texts. Default: 0

lm_infer_batch_size: int

Batch size used for computing text embeddings with a static LM. Default: 16

activation: func

The activation function. Default: None

dropout: float

The dropout parameter. Default: 0.0

use_node_embeddings: bool

Whether to use learnable node embeddings for individual nodes even when node features are available. Default: False

use_fp16: bool

Use float16 to store the BERT embeddings. Default: True

cached_embed_path: str

The path where the LM embeddings are cached. Default: None
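To make `lm_infer_batch_size` and `use_fp16` more concrete, the sketch below shows how a batch size slices a set of nodes into chunks and how the storage precision changes the size of the stored embeddings. It is a rough illustration under stated assumptions, not GraphStorm code: the helper names `iter_batches` and `embedding_cache_bytes` and the node counts are hypothetical, and 768 is the hidden size of `bert-base-uncased`.

```python
import struct


def iter_batches(num_nodes, batch_size=16):
    """Yield (start, end) index ranges that cover num_nodes in order."""
    for start in range(0, num_nodes, batch_size):
        yield start, min(start + batch_size, num_nodes)


def embedding_cache_bytes(num_nodes, hidden_size, use_fp16=True):
    """Bytes needed to store one embedding row per node."""
    # struct format 'e' is IEEE 754 half precision (2 bytes),
    # 'f' is single precision (4 bytes).
    bytes_per_value = struct.calcsize("e" if use_fp16 else "f")
    return num_nodes * hidden_size * bytes_per_value


batches = list(iter_batches(num_nodes=50, batch_size=16))
print(batches)  # [(0, 16), (16, 32), (32, 48), (48, 50)]

# Hypothetical graph with one million text nodes.
print(embedding_cache_bytes(1_000_000, 768, use_fp16=True))   # 1536000000
print(embedding_cache_bytes(1_000_000, 768, use_fp16=False))  # 3072000000
```

Under these assumptions, float16 storage halves the footprint of the stored embeddings relative to float32.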

Examples:

from graphstorm import get_node_feat_size
from graphstorm.model import GSgnnNodeModel, GSLMNodeEncoderInputLayer
from graphstorm.dataloading import GSgnnNodeTrainData

np_data = GSgnnNodeTrainData(...)
model = GSgnnNodeModel(...)

# Feature size of each node type, read from the 'feat' feature field.
feat_size = get_node_feat_size(np_data.g, 'feat')

# One LM configuration: BERT applied to the text of node type 'a'.
node_lm_configs = [{"lm_type": "bert",
                    "model_name": "bert-base-uncased",
                    "gradient_checkpoint": True,
                    "node_types": ['a']}]
# Number of trainable texts, passed as num_train below.
lm_train_nodes = 10

encoder = GSLMNodeEncoderInputLayer(
    g=np_data.g,
    node_lm_configs=node_lm_configs,
    feat_size=feat_size,
    embed_size=128,
    num_train=lm_train_nodes
)
# Use the LM-aware layer as the model's node input encoder.
model.set_node_input_encoder(encoder)
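Before constructing the layer, it can help to sanity-check `node_lm_configs`. The helper below is a hypothetical pre-flight check, not part of GraphStorm, which may perform its own validation. It assumes that the four keys shown in the example above are all required and that each node type should be covered by at most one language model.

```python
# Keys taken from the example configuration; treating all of them as
# required is an assumption of this sketch.
REQUIRED_KEYS = {"lm_type", "model_name", "gradient_checkpoint", "node_types"}


def check_node_lm_configs(node_lm_configs):
    """Raise ValueError on a malformed config; return the covered node types."""
    seen = set()
    for i, cfg in enumerate(node_lm_configs):
        missing = REQUIRED_KEYS - set(cfg)
        if missing:
            raise ValueError(f"config {i} is missing keys: {sorted(missing)}")
        if not cfg["node_types"]:
            raise ValueError(f"config {i} lists no node types")
        overlap = seen & set(cfg["node_types"])
        if overlap:
            # Assumption of this sketch: one language model per node type.
            raise ValueError(f"node types configured twice: {sorted(overlap)}")
        seen |= set(cfg["node_types"])
    return seen


covered = check_node_lm_configs([{"lm_type": "bert",
                                  "model_name": "bert-base-uncased",
                                  "gradient_checkpoint": True,
                                  "node_types": ["a"]}])
print(covered)
```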