GATv2Encoder
- class graphstorm.model.GATv2Encoder(h_dim, out_dim, num_heads, num_hidden_layers=1, edge_feat_name=None, dropout=0, activation=<function relu>, last_layer_act=False, num_ffn_layers_in_gnn=0)
Bases: GraphConvEncoder
GATv2 Conv Encoder.
The GATv2Encoder employs several GATv2ConvLayer layers as its encoding mechanism. It is intended to be set as the model's GNN encoder within GraphStorm (for example, through set_gnn_encoder()).
Examples:
    # Build model and do full-graph inference on GATv2Encoder
    from graphstorm import get_node_feat_size
    from graphstorm.model import GATv2Encoder
    from graphstorm.model import EntityClassifier
    from graphstorm.model import GSgnnNodeModel, GSNodeEncoderInputLayer
    from graphstorm.dataloading import GSgnnData
    from graphstorm.model import do_full_graph_inference

    np_data = GSgnnData(...)

    model = GSgnnNodeModel(alpha_l2norm=0)
    feat_size = get_node_feat_size(np_data.g, "feat")
    # Input layer maps the raw "feat" node features to a 4-dim embedding.
    encoder = GSNodeEncoderInputLayer(np_data.g, feat_size, 4,
                                      dropout=0,
                                      use_node_embeddings=True)
    model.set_node_input_encoder(encoder)

    # h_dim=4, out_dim=4, 2 attention heads, num_hidden_layers + 1 = 2 GNN layers.
    gnn_encoder = GATv2Encoder(4, 4, num_heads=2, num_hidden_layers=1)
    model.set_gnn_encoder(gnn_encoder)
    # 3-class, single-label node classification decoder.
    model.set_decoder(EntityClassifier(model.gnn_encoder.out_dims, 3, False))

    h = do_full_graph_inference(model, np_data)
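For readers unfamiliar with the layer, the following is a minimal, framework-free sketch of the GATv2 attention rule for a single head and a single destination node, written in the form documented for DGL's GATv2Conv: the score is e_ij = a^T LeakyReLU(W_left h_i + W_right h_j), the scores are softmax-normalized over the neighbors, and the output is the attention-weighted sum of W_right h_j. It assumes that GraphStorm's GATv2ConvLayer follows this rule; it is not GraphStorm's implementation, and the function names, toy weights, and LeakyReLU slope of 0.2 are illustrative assumptions. How multiple heads (num_heads > 1) are combined is not modeled.

```python
import math


def leaky_relu(x, negative_slope=0.2):
    # Elementwise LeakyReLU; 0.2 is an assumed slope (DGL's GATv2Conv default).
    return [v if v > 0 else negative_slope * v for v in x]


def matvec(W, x):
    # Multiply matrix W (a list of rows) by vector x.
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def gatv2_head(h_dst, h_srcs, W_left, W_right, a):
    """Illustrative single-head GATv2 aggregation for one destination node.

    e_j   = a . LeakyReLU(W_left h_dst + W_right h_j)
    alpha = softmax over neighbors j of e_j
    out   = sum_j alpha_j * (W_right h_j)
    """
    left = matvec(W_left, h_dst)
    msgs = [matvec(W_right, h_j) for h_j in h_srcs]
    # The nonlinearity is applied before the dot product with a, which is
    # what distinguishes GATv2 ("dynamic" attention) from the original GAT.
    scores = []
    for m in msgs:
        z = leaky_relu([p + q for p, q in zip(left, m)])
        scores.append(sum(ai * zi for ai, zi in zip(a, z)))
    # Numerically stable softmax over the neighbors.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    alpha = [e / total for e in exps]
    # Attention-weighted sum of the transformed neighbor features.
    out = [sum(w * m[k] for w, m in zip(alpha, msgs)) for k in range(len(msgs[0]))]
    return alpha, out


# Toy example: one destination node with two neighbors and identity weights.
I2 = [[1.0, 0.0], [0.0, 1.0]]
alpha, out = gatv2_head([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], I2, I2, [1.0, -1.0])
print(alpha, out)
```

In the toy example the first neighbor receives the larger weight because its score (2.0) exceeds the second neighbor's (0.0).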
Parameters
- h_dim: int
Hidden dimension size.
- out_dim: int
Output dimension size.
- num_heads: int
Number of attention heads in multi-head attention.
- num_hidden_layers: int
Number of hidden layers. The total number of GNN layers is equal to num_hidden_layers + 1. Default: 1.
- edge_feat_name: str
Name of the edge features used. Default: None.
- dropout: float
Dropout rate. Default: 0.
- activation: torch.nn.functional
Activation function. Default: relu.
- last_layer_act: bool
Whether to apply the activation function in the last GNN layer. Default: False.
- num_ffn_layers_in_gnn: int
Number of feed-forward network (FFN) layers between GNN layers. Default: 0.
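The layer-count rule for num_hidden_layers, together with last_layer_act, can be summarized by the hypothetical helper below (it is not part of GraphStorm). It assumes that every hidden layer maps h_dim to h_dim, that the final layer maps h_dim to out_dim, and that hidden layers always apply the activation; these per-layer dimensions are an illustrative assumption rather than a description of GraphStorm internals. num_heads and num_ffn_layers_in_gnn are not modeled.

```python
def gnn_layer_plan(h_dim, out_dim, num_hidden_layers=1, last_layer_act=False):
    """Hypothetical helper: one (in_dim, out_dim, uses_activation) tuple per GNN layer.

    Total GNN layers = num_hidden_layers + 1. Assumes hidden layers keep h_dim,
    the final layer projects to out_dim, and only the final layer's activation
    depends on last_layer_act.
    """
    plan = [(h_dim, h_dim, True) for _ in range(num_hidden_layers)]
    plan.append((h_dim, out_dim, last_layer_act))
    return plan


# Same hyperparameters as the example: GATv2Encoder(4, 4, num_heads=2, num_hidden_layers=1)
print(gnn_layer_plan(4, 4, num_hidden_layers=1))  # [(4, 4, True), (4, 4, False)]
```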
- forward(blocks, h, edge_feats=None)
GATv2 encoder forward computation.
Parameters
- blocks: list of DGL MFGs
Sampled subgraphs in the format of a list of DGL message flow graphs (MFGs). More detailed information about DGL MFGs can be found in the DGL Neighbor Sampling Overview.
- h: dict of Tensor
Node features for the default node type in the format of {dgl.DEFAULT_NTYPE: tensor}. The definition of dgl.DEFAULT_NTYPE can be found at the DGL official GitHub site.
- edge_feats: list of dict of Tensor
Input edge features for each edge type in the format of [{etype: tensor}, …], or [{}, {}, …] when the input blocks contain no edges. The length of edge_feats should be equal to the number of GNN layers. Default: None.
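Because edge_feats must contain one dictionary per GNN layer, and the number of GNN layers is num_hidden_layers + 1, a caller can build and validate the argument before calling forward(). The two helpers below are hypothetical (they are not part of GraphStorm), use plain Python objects in place of tensors, and check only the list length and the per-layer dict type described above.

```python
def empty_edge_feats(num_hidden_layers):
    """Hypothetical helper: [{}, {}, ...] with one empty dict per GNN layer."""
    return [{} for _ in range(num_hidden_layers + 1)]


def check_edge_feats(edge_feats, num_hidden_layers):
    """Hypothetical helper: raise ValueError if edge_feats has the wrong shape."""
    if edge_feats is None:
        return  # None is the documented default (no edge features).
    expected = num_hidden_layers + 1
    if len(edge_feats) != expected:
        raise ValueError(
            f"edge_feats has {len(edge_feats)} entries, expected {expected} "
            "(one per GNN layer)")
    for i, layer_feats in enumerate(edge_feats):
        if not isinstance(layer_feats, dict):
            raise ValueError(f"edge_feats[{i}] must be a dict mapping etype to tensor")


feats = empty_edge_feats(num_hidden_layers=1)
check_edge_feats(feats, num_hidden_layers=1)  # passes: two GNN layers, two dicts
print(feats)  # [{}, {}]
```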
Returns
- h: dict of Tensor
New node embeddings for the default node type in the format of {dgl.DEFAULT_NTYPE: tensor}. The definition of dgl.DEFAULT_NTYPE can be found at the DGL official GitHub site.
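Both the input h and the returned embeddings are single-entry dictionaries keyed by dgl.DEFAULT_NTYPE. The sketch below shows that packing and unpacking with plain Python lists standing in for tensors. It hard-codes "_N", the value of dgl.DEFAULT_NTYPE, only so that it runs without DGL installed; in real code, import dgl and use dgl.DEFAULT_NTYPE directly. The helper names are hypothetical.

```python
DEFAULT_NTYPE = "_N"  # Same value as dgl.DEFAULT_NTYPE; hard-coded so DGL is not needed.


def pack_default_ntype(node_feats):
    """Hypothetical helper: wrap features as {DEFAULT_NTYPE: feats} for forward()."""
    return {DEFAULT_NTYPE: node_feats}


def unpack_default_ntype(h):
    """Hypothetical helper: extract the embeddings from a forward() result."""
    if set(h) != {DEFAULT_NTYPE}:
        raise KeyError(f"expected only the key {DEFAULT_NTYPE!r}, got {sorted(h)}")
    return h[DEFAULT_NTYPE]


h_in = pack_default_ntype([[0.1, 0.2], [0.3, 0.4]])  # stand-in for a 2 x 2 tensor
print(h_in)  # {'_N': [[0.1, 0.2], [0.3, 0.4]]}
print(unpack_default_ntype(h_in))  # [[0.1, 0.2], [0.3, 0.4]]
```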