GATConvwithEdgeFeat
- class graphstorm.model.GATConvwithEdgeFeat(in_feats, out_feats, num_heads, edge_feat_mp_op='concat', feat_drop=0.0, attn_drop=0.0, negative_slope=0.2, self_loop=True, activation=None, bias=True)
Bases: Module
Graph attention layer that supports edge features in the message passing computation.
New in version 0.4.1: Added the GATConvwithEdgeFeat class to support edge features in the message passing computation.
Parameters
- in_feats: int
Input feature size.
- out_feats: int
Output feature size.
- num_heads: int
Number of heads in Multi-Head Attention.
- edge_feat_mp_op: str
The operation used to combine source node embeddings with edge embeddings in message passing. Options are concat, add, sub, mul, and div. Default: concat.
The concat operation concatenates the source node features with the edge features; the add operation adds the edge features to the source node features; the sub operation subtracts the edge features from the source node features; the mul operation multiplies the source node features by the edge features; and the div operation divides the source node features by the edge features.
- feat_drop: float, optional
Dropout rate on features. Default: 0.
- attn_drop: float, optional
Dropout rate on attention weights. Default: 0.
- negative_slope: float, optional
LeakyReLU angle of negative slope. Default: 0.2.
- residual: bool, optional
If True, use a residual connection. Default: False.
- activation: callable activation function, optional
If not None, applies an activation function to the updated node features. Default: None.
- bias: bool
Whether to add bias. Default: True.
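The five edge_feat_mp_op options can be illustrated with a small, framework-free sketch. This is not the GraphStorm implementation: the helper name combine and the plain Python lists are hypothetical stand-ins for tensors, and the sketch assumes that the element-wise options receive source and edge vectors of equal length.

```python
import operator

# Element-wise options; "concat" is handled separately because it changes the size.
ELEMENTWISE_OPS = {
    "add": operator.add,
    "sub": operator.sub,
    "mul": operator.mul,
    "div": operator.truediv,
}


def combine(src_feat, edge_feat, op="concat"):
    """Combine one source-node vector with one edge vector (hypothetical helper)."""
    if op == "concat":
        # Output length is len(src_feat) + len(edge_feat).
        return list(src_feat) + list(edge_feat)
    if op not in ELEMENTWISE_OPS:
        raise ValueError(f"unknown edge_feat_mp_op: {op}")
    if len(src_feat) != len(edge_feat):
        # Assumption of this sketch: element-wise ops need matching sizes.
        raise ValueError("element-wise ops need vectors of equal length")
    fn = ELEMENTWISE_OPS[op]
    return [fn(s, e) for s, e in zip(src_feat, edge_feat)]


src = [2.0, 4.0]
edge = [1.0, 2.0]
for op in ("concat", "add", "sub", "mul", "div"):
    print(op, combine(src, edge, op))
```

Note that concat is the only option that changes the size of the message, and that div is undefined wherever an edge feature is zero.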
- forward(rel_graph, inputs, get_attention=False, weight=None, edge_weight=None)
GAT conv forward computation with edge feature.
Parameters
- rel_graph: DGLGraph
Input DGL heterogeneous graph with one edge type only.
- inputs: tuple of Tensors
Tuple of input node and edge features, for example (src_inputs, dst_inputs, edge_inputs).
- get_attention: bool, optional
Whether to return the attention values. Default: False.
- weight: dict of Tensor, optional
External node weight tensor. Not implemented; reserved for future use.
- edge_weight: Tensor, optional
External edge weight tensor. Not implemented; reserved for future use.
Returns
- h: Tensor
New node embeddings for destination node type.
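The negative_slope and get_attention parameters both concern per-edge attention values. This reference does not spell out how they are computed, so the following is only a sketch that assumes the standard GAT formulation: raw scores pass through LeakyReLU and are then normalized with a softmax over the incoming edges of one destination node. The function names and the example scores are hypothetical.

```python
import math


def leaky_relu(x, negative_slope=0.2):
    """LeakyReLU: identity for x >= 0, scaled by negative_slope otherwise."""
    return x if x >= 0 else negative_slope * x


def attention_weights(raw_scores, negative_slope=0.2):
    """Softmax over LeakyReLU-activated scores of one node's incoming edges.

    Assumes at least one incoming edge (raw_scores is non-empty).
    """
    activated = [leaky_relu(s, negative_slope) for s in raw_scores]
    shift = max(activated)  # subtract the maximum for numerical stability
    exps = [math.exp(a - shift) for a in activated]
    total = sum(exps)
    return [e / total for e in exps]


# Three incoming edges with made-up raw scores.
weights = attention_weights([1.0, -2.0, 0.5])
print(weights)
```

Under this assumption the weights for each destination node are positive and sum to 1, and a smaller negative_slope pushes negative scores closer to zero before normalization.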