RelationalGCNEncoder
- class graphstorm.model.RelationalGCNEncoder(g, h_dim, out_dim, num_bases=-1, num_hidden_layers=1, dropout=0, use_self_loop=True, last_layer_act=False, num_ffn_layers_in_gnn=0, norm=None)
Bases: GraphConvEncoder, GSgnnGNNEncoderInterface
Relational graph convolution encoder.
The RelationalGCNEncoder employs several RelGraphConvLayer layers as its encoding mechanism. The RelationalGCNEncoder should be designated as the model's encoder within GraphStorm.
Parameters
- g: DistGraph
The distributed graph.
- h_dim: int
Hidden dimension.
- out_dim: int
Output dimension.
- num_bases: int
Number of bases. If None, the number of relation types is used. Default: -1, as given in the class signature.
- num_hidden_layers: int
Number of hidden layers. The total number of GNN layers equals num_hidden_layers + 1. Default: 1.
- dropout: float
Dropout rate. Default: 0.
- use_self_loop: bool
Whether to add a self-loop. Default: True.
- last_layer_act: callable
Activation for the last layer. Default: False, as given in the class signature.
- num_ffn_layers_in_gnn: int
Number of feed-forward network (FFN) layers between GNN layers. Default: 0.
- norm: str
Normalization method. Options: batch, layer, and None. Default: None, meaning no normalization.
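To make the role of num_bases more concrete, here is a minimal, dependency-free sketch of the basis decomposition used in the standard RGCN formulation, where each relation weight is a linear combination of a small set of shared basis matrices. That RelGraphConvLayer builds its weights in exactly this way is an assumption. The helper name compose_relation_weights and all sizes below are hypothetical and are not part of the GraphStorm API.

```python
import random


def compose_relation_weights(bases, coeffs):
    """Return one weight matrix per relation: W_r = sum_b coeffs[r][b] * bases[b]."""
    num_rows = len(bases[0])
    num_cols = len(bases[0][0])
    weights = []
    for rel_coeffs in coeffs:
        w = [[0.0] * num_cols for _ in range(num_rows)]
        for c, basis in zip(rel_coeffs, bases):
            for i in range(num_rows):
                for j in range(num_cols):
                    w[i][j] += c * basis[i][j]
        weights.append(w)
    return weights


# Hypothetical sizes, chosen only for illustration.
num_rels, num_bases, in_dim, out_dim = 5, 2, 4, 4

rng = random.Random(0)
# Shared basis matrices, each of shape (in_dim, out_dim).
bases = [[[rng.uniform(-1, 1) for _ in range(out_dim)]
          for _ in range(in_dim)]
         for _ in range(num_bases)]
# One coefficient per (relation, basis) pair.
coeffs = [[rng.uniform(-1, 1) for _ in range(num_bases)]
          for _ in range(num_rels)]

rel_weights = compose_relation_weights(bases, coeffs)

# Learnable parameter counts with and without the decomposition.
params_with_bases = num_bases * in_dim * out_dim + num_rels * num_bases
params_without_bases = num_rels * in_dim * out_dim
```

With these sizes the decomposition stores fewer parameters than one full matrix per relation, which is the usual motivation for setting num_bases below the number of relation types.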
Examples:
    # Build model and do full-graph inference on RelationalGCNEncoder
    from graphstorm import get_node_feat_size
    from graphstorm.model import RelationalGCNEncoder
    from graphstorm.model import EntityClassifier
    from graphstorm.model import GSgnnNodeModel, GSNodeEncoderInputLayer
    from graphstorm.dataloading import GSgnnData
    from graphstorm.model import do_full_graph_inference

    np_data = GSgnnData(...)
    g = np_data.g  # the distributed graph held by the dataset

    model = GSgnnNodeModel(alpha_l2norm=0)

    # Input layer: project the "feat" node features to a hidden size of 4
    feat_size = get_node_feat_size(g, "feat")
    encoder = GSNodeEncoderInputLayer(g, feat_size, 4,
                                      dropout=0,
                                      use_node_embeddings=True)
    model.set_node_input_encoder(encoder)

    # GNN encoder: h_dim=4, out_dim=4, two GNN layers in total
    gnn_encoder = RelationalGCNEncoder(g, 4, 4,
                                       num_hidden_layers=1,
                                       dropout=0,
                                       use_self_loop=True,
                                       norm="batch")
    model.set_gnn_encoder(gnn_encoder)

    # Decoder: 3-class node classifier
    model.set_decoder(EntityClassifier(model.gnn_encoder.out_dims, 3, False))

    h = do_full_graph_inference(model, np_data)
- forward(blocks, h)
RGCN encoder forward computation.
Parameters
- blocks: list of DGL MFGs
Sampled subgraphs, given as a list of DGL message flow graphs (MFGs). More detailed information about DGL MFGs can be found in DGL Neighbor Sampling Overview.
- h: dict of Tensor
Input node features for each node type in the format of {ntype: tensor}.
Returns
- h: dict of Tensor
New node embeddings for each node type in the format of {ntype: tensor}.
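The input and output contract of forward can be illustrated with a small stand-in that does not depend on GraphStorm or DGL. This is only a sketch: ToyLayer and ToyEncoder are hypothetical names, the layers merely scale the features instead of aggregating over edges, and the pairing of one block with each of the num_hidden_layers + 1 layers is an assumption about how the encoder consumes blocks.

```python
class ToyLayer:
    """Hypothetical stand-in for one GNN layer: {ntype: rows} -> {ntype: rows}."""

    def __init__(self, scale):
        self.scale = scale

    def __call__(self, block, h):
        # A real layer would aggregate messages along the edges of `block`.
        # Here the features are only scaled, to keep the sketch self-contained.
        return {ntype: [[self.scale * x for x in row] for row in rows]
                for ntype, rows in h.items()}


class ToyEncoder:
    """Hypothetical encoder that stacks num_hidden_layers + 1 layers."""

    def __init__(self, num_hidden_layers=1):
        self.layers = [ToyLayer(2.0) for _ in range(num_hidden_layers + 1)]

    def forward(self, blocks, h):
        if len(blocks) != len(self.layers):
            raise ValueError("expected one block per GNN layer")
        # Each layer consumes one block and the current per-type features.
        for layer, block in zip(self.layers, blocks):
            h = layer(block, h)
        return h


encoder = ToyEncoder(num_hidden_layers=1)
blocks = [None, None]  # placeholders standing in for two DGL MFGs
h_in = {"user": [[1.0, 2.0]], "item": [[0.5, 0.5], [1.0, 0.0]]}
h_out = encoder.forward(blocks, h_in)
```

The point of the sketch is the shape of the interface: a dictionary keyed by node type goes in, and a dictionary with the same keys comes out after every layer has been applied.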