graphstorm.model.GSNodeEncoderInputLayer

class graphstorm.model.GSNodeEncoderInputLayer(g, feat_size, embed_size, activation=None, dropout=0.0, use_node_embeddings=False, force_no_embeddings=None, num_ffn_layers_in_input=0, ffn_activation=<function relu>, cache_embed=False, use_wholegraph_sparse_emb=False)

The input encoder layer for all nodes in a heterogeneous graph.

For nodes without features, the input layer adds learnable embeddings. For nodes with features, it adds a linear layer that projects the features to a specified dimension. A user can also add learnable embeddings to nodes that have features. In this case, the input layer combines the node features with the learnable embeddings and projects them to the specified dimension.
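The per-node-type choice described above can be sketched as a small decision function. This is only an illustration in plain Python, not GraphStorm code: `plan_input_encoders` is a hypothetical helper, and it assumes that a feature size of 0 means "no features" and that a node type listed in `force_no_embeddings` never receives a learnable embedding.

```python
from typing import Dict, Iterable, Optional


def plan_input_encoders(feat_size: Dict[str, int],
                        use_node_embeddings: bool = False,
                        force_no_embeddings: Optional[Iterable[str]] = None) -> Dict[str, str]:
    """Return, per node type, which input path the description implies.

    Hypothetical helper for illustration only; not part of GraphStorm.
    """
    blocked = set(force_no_embeddings or [])
    plan = {}
    for ntype, size in feat_size.items():
        # Assumption: a feature size of 0 means the node type has no features.
        has_feat = size > 0
        # Assumption: a blocked node type never gets a learnable embedding.
        wants_embed = ntype not in blocked and (use_node_embeddings or not has_feat)
        if has_feat and wants_embed:
            plan[ntype] = "features + learnable embedding, then projection"
        elif has_feat:
            plan[ntype] = "linear projection of features"
        elif wants_embed:
            plan[ntype] = "learnable embedding only"
        else:
            plan[ntype] = "no input representation"
    return plan


# Example with made-up node types and feature sizes.
example_plan = plan_input_encoders(
    {"paper": 128, "author": 0, "venue": 0},
    use_node_embeddings=False,
    force_no_embeddings=["venue"],
)
```

With these made-up inputs, `paper` is projected from its features, `author` falls back to a learnable embedding, and `venue` gets neither. The real layer may handle the combination of `use_node_embeddings` and `force_no_embeddings` differently, so treat the precedence here as a guess.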

Parameters

g: DistGraph

The distributed graph

feat_size: dict of int

The original feature sizes, keyed by node type.

embed_size: int

The embedding size

activation: func

The activation function

dropout: float

The dropout parameter

use_node_embeddings: bool

Whether we will use learnable embeddings for individual nodes even when node features are available.

force_no_embeddings: list of str

The list of node types that are forced to not have learnable embeddings.

num_ffn_layers_in_input: int, optional

The number of feedforward network layers applied to each node type in the input layer.

ffn_activation: callable

The activation function for the feedforward neural networks.

cache_embed: bool

Whether or not to cache the embeddings.

use_wholegraph_sparse_emb: bool

Whether or not to use WholeGraph to host embeddings for sparse updates.

Examples:

from graphstorm import get_node_feat_size
from graphstorm.model import GSgnnNodeModel, GSNodeEncoderInputLayer
from graphstorm.dataloading import GSgnnNodeTrainData

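# Load the node-task training data (constructor arguments elided).
np_data = GSgnnNodeTrainData(...)

model = GSgnnNodeModel(alpha_l2norm=0)
# Look up the feature size of each node type for the 'feat' feature.
feat_size = get_node_feat_size(np_data.g, 'feat')
# Project every node type to 4 dimensions and add learnable embeddings
# even for node types that already have features.
encoder = GSNodeEncoderInputLayer(np_data.g, feat_size, 4,
                                  dropout=0,
                                  use_node_embeddings=True)
model.set_node_input_encoder(encoder)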