GSPureLearnableInputLayer
- class graphstorm.model.GSPureLearnableInputLayer(g, embed_size, use_wholegraph_sparse_emb=False)
Bases: GSNodeInputLayer
The node encoder input layer for heterogeneous graphs
that uses learnable embeddings for every node.
New in version 0.4.2: Added GSPureLearnableInputLayer to support knowledge graph embedding training.
Parameters
- g: DistGraph
The input DGL distributed graph.
- embed_size: int
The output embedding size.
- use_wholegraph_sparse_emb: bool
Whether or not to use WholeGraph to host embeddings for sparse updates. Default: False.
Examples:
    from graphstorm.model import GSgnnNodeModel, GSPureLearnableInputLayer
    from graphstorm.dataloading import GSgnnData

    np_data = GSgnnData(...)

    model = GSgnnNodeModel(alpha_l2norm=0)
    encoder = GSPureLearnableInputLayer(g, embed_size=4)
    model.set_node_input_encoder(encoder)
- forward(input_feats, input_nodes)
Input layer forward computation.
Parameters
- input_feats: dict of Tensor
The input features in the format of {ntype: feats}. Note: this argument is ignored, because the layer relies only on learnable embeddings.
- input_nodes: dict of Tensor
The input node indexes in the format of {ntype: indexes}.
Returns
- embs: dict of Tensor
The projected node embeddings in the format of {ntype: emb}.
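As a rough illustration of the per-node-type lookup described above, the sketch below uses plain Python lists in place of tensors. `ToyLearnableInput` and its attributes are hypothetical names, not part of GraphStorm; the real layer is built from a DistGraph and returns tensors.

```python
import random


class ToyLearnableInput:
    """Toy stand-in for a pure learnable input layer (not GraphStorm code)."""

    def __init__(self, num_nodes, embed_size, seed=0):
        # num_nodes: {ntype: node count}; one embedding table per node type.
        rng = random.Random(seed)
        self.embed_size = embed_size
        self.tables = {
            ntype: [[rng.uniform(-1.0, 1.0) for _ in range(embed_size)]
                    for _ in range(count)]
            for ntype, count in num_nodes.items()
        }

    def forward(self, input_feats, input_nodes):
        # input_feats is accepted for interface compatibility but never read.
        return {
            ntype: [self.tables[ntype][idx] for idx in indexes]
            for ntype, indexes in input_nodes.items()
        }


layer = ToyLearnableInput({"user": 5, "item": 3}, embed_size=4)
embs = layer.forward(input_feats={}, input_nodes={"user": [0, 2], "item": [1]})
# Report (number of rows, embedding size) per node type.
print({ntype: (len(rows), len(rows[0])) for ntype, rows in embs.items()})
# prints {'user': (2, 4), 'item': (1, 4)}
```

The output keeps the {ntype: emb} layout of the input indexes, and passing different `input_feats` does not change the result.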
- require_cache_embed()
Whether to cache the embeddings for inference.
If the input layer encoder includes heavy computations, such as BERT computations, it should return
True, and the inference engine will cache the embeddings from the input layer encoder.
Returns
bool: True if we need to cache the embeddings for inference.
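One way to picture how an inference loop might consume this flag is sketched below. `CheapEncoder`, `ExpensiveEncoder`, and `encode_for_inference` are hypothetical names used only for illustration; GraphStorm's own inference engine is not shown here and may work differently.

```python
class CheapEncoder:
    """Toy encoder whose forward pass is a cheap lookup, so no caching is requested."""

    def __init__(self):
        self.calls = 0  # counts how often forward() actually runs

    def require_cache_embed(self):
        return False

    def forward(self, node_id):
        self.calls += 1
        return [float(node_id)]


class ExpensiveEncoder(CheapEncoder):
    """Toy encoder that pretends its forward pass is costly and asks for caching."""

    def require_cache_embed(self):
        return True


def encode_for_inference(encoder, node_ids, cache):
    """Return one embedding per node, reusing cached results only if the encoder asks for it."""
    out = []
    for nid in node_ids:
        if encoder.require_cache_embed():
            if nid not in cache:
                cache[nid] = encoder.forward(nid)
            out.append(cache[nid])
        else:
            out.append(encoder.forward(nid))
    return out


cheap, costly = CheapEncoder(), ExpensiveEncoder()
encode_for_inference(cheap, [0, 1, 0], cache={})
encode_for_inference(costly, [0, 1, 0], cache={})
print(cheap.calls, costly.calls)  # prints: 3 2
```

In this toy setup the repeated node 0 is recomputed by the cheap encoder but served from the cache for the costly one.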
- get_sparse_params()
Get the sparse parameters of this input layer.
This function is normally called by optimizers to update sparse model parameters, i.e., learnable node embeddings.
Returns
list of Tensors: the sparse embeddings, or an empty list if there are no sparse parameters.
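A "sparse update" here means that only the embedding rows touched by a mini-batch are modified. The sketch below shows that idea with a plain SGD step over Python lists; `sparse_sgd_step` is a hypothetical helper, not a GraphStorm or DGL function, and real sparse optimizers are more involved.

```python
def sparse_sgd_step(table, row_grads, lr=0.1):
    """Apply an SGD step only to rows that received a gradient.

    table: list of embedding rows (each a list of floats).
    row_grads: {row index: gradient row}, covering only nodes in the batch.
    """
    for idx, grad in row_grads.items():
        table[idx] = [w - lr * g for w, g in zip(table[idx], grad)]
    return table


table = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
# Only node 1 appeared in this (made-up) batch, so only its row changes.
sparse_sgd_step(table, {1: [10.0, -10.0]}, lr=0.1)
print(table)
```

Rows 0 and 2 are left untouched, which is what makes this kind of update cheap when the embedding table is much larger than a batch.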
- property in_dims
This layer does not accept input features.
- property out_dims
Return the number of output dimensions, which is given in class initialization.
- property use_wholegraph_sparse_emb
Return whether or not to use WholeGraph to host embeddings for sparse updates, which is given in class initialization.