HGTLayerwithEdgeFeat
- class graphstorm.model.HGTLayerwithEdgeFeat(in_dim, out_dim, ntypes, canonical_etypes, num_heads, edge_feat_name=None, edge_feat_mp_op='concat', activation=None, dropout=0.2, norm='layer', num_ffn_layers_in_gnn=0, fnn_activation=<function relu>)
Bases: HGTLayer
Heterogeneous graph transformer (HGT) layer with edge feature support.
New in version 0.4.1: Added a new HGT layer that supports edge features.
This class extends HGTLayer.
This class uses a simple approach to incorporate edge features into the original HGT model: it combines the source node embeddings with the edge embeddings to form the new K and V, and then uses these new K and V in the HGT formulas. The combination operators are the same as those of the RGCN conv model: concat, add, sub, mul, and div. The formula for computing the message sent along each edge \((s, e, t)\) then becomes:
\[\text{MSG-head}^i(s, e, t) = \text{M-Linear}^i_{\tau(s)}\left(H^{(l-1)}[s] \;\text{op}\; EF_{e}\right) W^{MSG}_{\phi(e)}\]
where \(\text{op}\) is one of the add, sub, mul, and div operators, and \(EF_{e}\) is the feature of edge \(e\), whose edge type is \(\phi(e)\).
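To make the non-concat formula concrete, the following pure-Python sketch computes one head of the message for a single edge. It is an illustration under stated assumptions, not the GraphStorm implementation: the helper names (`ELEMENTWISE_OPS`, `matvec`, `message_head`) are hypothetical, \(\text{M-Linear}\) is modeled as a bias-free matrix, and the edge feature is assumed to already have the same length as the source node embedding so that the elementwise operator is defined.

```python
import operator

# Hypothetical mapping from the non-concat edge_feat_mp_op options to
# elementwise Python operators.
ELEMENTWISE_OPS = {
    "add": operator.add,
    "sub": operator.sub,       # source node feature minus edge feature
    "mul": operator.mul,
    "div": operator.truediv,   # source node feature divided by edge feature
}


def matvec(x, w):
    """Row vector x (length in_dim) times matrix w (in_dim x out_dim, list of rows)."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def message_head(h_src, ef, w_m_linear, w_msg, op):
    """One head of MSG-head^i(s, e, t) = M-Linear(H[s] op EF_e) W^MSG.

    Assumes h_src and ef have the same length and that M-Linear has no bias.
    """
    if len(h_src) != len(ef):
        raise ValueError("source node embedding and edge feature must have the same length")
    # Combine the source node embedding with the edge feature elementwise.
    combined = [ELEMENTWISE_OPS[op](a, b) for a, b in zip(h_src, ef)]
    # Apply the M-Linear projection, then the edge-type-specific W^MSG.
    return matvec(matvec(combined, w_m_linear), w_msg)


identity = [[1.0, 0.0], [0.0, 1.0]]
w_msg = [[2.0, 0.0], [0.0, 1.0]]
print(message_head([1.0, 2.0], [3.0, 4.0], identity, w_msg, "add"))  # [8.0, 6.0]
```

In the real layer each head has its own projection and the computation runs on batched tensors; the sketch keeps a single head and a single edge so the order of operations in the formula is easy to follow.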
For the concat operator, the formula is
\[\text{MSG-head}^i(s, e, t) = \left(\text{M-Linear}^i_{\tau(s)}(H^{(l-1)}[s]) + \text{EF-Linear}^i_{\phi(e)}(EF_{e})\right) W^{MSG}_{\phi(e)}\]
where \(\text{EF-Linear}^i_{\phi(e)}\) is an additional weight specific to the edge type \(\phi(e)\).
This formula uses a linear algebra trick to implement the concatenation operation: the linear computation \(\text{concat}([e_1, e_2], \text{dim}=-1) \,@\, w\) equals \(e_1 \,@\, w_1 + e_2 \,@\, w_2\), where the embeddings \(e_1\) and \(e_2\) have the same shape \((N, in\_dim)\), the weight \(w\) has shape \((2 \cdot in\_dim, out\_dim)\), and the weights \(w_1\) and \(w_2\), which are the top and bottom halves of \(w\), each have shape \((in\_dim, out\_dim)\). Based on this trick, instead of first concatenating the source node embeddings with the edge embeddings and then applying an edge-type-specific weight of shape \((2 \cdot in\_dim, out\_dim)\), this implementation uses two separate weights, one for the source node type and one for the edge type, to transform each embedding, and then adds the transformed embeddings together.
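The concatenation trick can be checked numerically. The sketch below uses plain Python lists as matrices (all helper names are hypothetical, not GraphStorm functions), builds \(w\) by stacking \(w_1\) on top of \(w_2\), and confirms that both computations give the same result.

```python
def matmul(a, b):
    """Multiply matrix a (N x K) by matrix b (K x M); matrices are lists of rows."""
    return [[sum(row[k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for row in a]


def mat_add(a, b):
    """Elementwise sum of two matrices with the same shape."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def concat_then_project(e1, e2, w):
    """concat([e1, e2], dim=-1) @ w, where w has shape (2 * in_dim, out_dim)."""
    cat = [r1 + r2 for r1, r2 in zip(e1, e2)]  # list + list joins each row pair
    return matmul(cat, w)


def project_then_add(e1, e2, w1, w2):
    """e1 @ w1 + e2 @ w2, where w1 and w2 each have shape (in_dim, out_dim)."""
    return mat_add(matmul(e1, w1), matmul(e2, w2))


# N = 2 rows, in_dim = 2, out_dim = 2; integer values keep the comparison exact.
e1 = [[1, 2], [3, 4]]   # e.g. source node embeddings
e2 = [[5, 6], [7, 8]]   # e.g. edge embeddings
w1 = [[1, 0], [2, 1]]
w2 = [[0, 1], [1, 3]]
w = w1 + w2             # stack w1 on top of w2: shape (2 * in_dim, out_dim)

print(concat_then_project(e1, e2, w))     # [[11, 25], [19, 35]]
print(project_then_add(e1, e2, w1, w2))   # [[11, 25], [19, 35]]
```

Keeping \(w_1\) and \(w_2\) separate lets the source-node-type weight and the edge-type weight be stored and applied independently, which is what allows the layer to avoid materializing the concatenated embeddings.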
For the other HGT formulas, please refer to HGTLayer.
Parameters
- in_dim: int
Input dimension size.
- out_dim: int
Output dimension size.
- ntypes: list of str
List of node types in the format of [ntype1, ntype2, …].
- canonical_etypes: list of tuple
List of canonical edge types in the format of [(‘src_ntype1’, ‘etype1’, ‘dst_ntype1’), …].
- num_heads: int
Number of attention heads.
- edge_feat_name: dict of list of str
User-provided edge feature names in the format of {etype1: [feat1, feat2, …], etype2: […], …}, or None if not provided. Default: None.
- edge_feat_mp_op: str
The operation used to combine source node embeddings with edge embeddings in message passing. Options include concat, add, sub, mul, and div. The concat operation concatenates the source node features with the edge features; the add operation adds the source node features and the edge features together; the sub operation subtracts the edge features from the source node features; the mul operation multiplies the source node features by the edge features; and the div operation divides the source node features by the edge features. Default: concat.
- activation: callable
Activation function. Default: None.
- dropout: float
Dropout rate. Default: 0.2.
- norm: str
Normalization method. Options: batch, layer, and None. Default: layer.
- num_ffn_layers_in_gnn: int
Number of FFN layers between GNN layers. Default: 0.
- fnn_activation: torch.nn.functional
Activation function for the FFN layers. Default: relu.
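For illustration, the following sketch shows argument values in the formats described above, together with a hypothetical sanity check (`check_layer_args`) that is not part of GraphStorm. The node types, edge types, and feature names are made up. Because the description above writes the keys of edge_feat_name as etype1, etype2, the sketch accepts either a canonical edge type tuple or its relation name as a key; check the GraphStorm source for the exact key format.

```python
# Made-up example values in the documented formats.
ntypes = ["user", "item"]
canonical_etypes = [("user", "buys", "item"), ("item", "bought-by", "user")]
edge_feat_name = {("user", "buys", "item"): ["price", "quantity"]}

# The edge_feat_mp_op options listed in the parameter description.
MP_OPS = ("concat", "add", "sub", "mul", "div")


def check_layer_args(ntypes, canonical_etypes, edge_feat_name, edge_feat_mp_op):
    """Hypothetical sanity check for the argument formats; not part of GraphStorm."""
    if edge_feat_mp_op not in MP_OPS:
        raise ValueError(f"unsupported edge_feat_mp_op: {edge_feat_mp_op!r}")
    for src, rel, dst in canonical_etypes:
        if src not in ntypes or dst not in ntypes:
            raise ValueError(f"unknown node type in {(src, rel, dst)}")
    if edge_feat_name is None:  # edge features are optional
        return
    # Accept either a canonical edge type tuple or a bare relation name as a key.
    known = set(canonical_etypes) | {rel for _, rel, _ in canonical_etypes}
    for etype, feats in edge_feat_name.items():
        if etype not in known:
            raise ValueError(f"unknown edge type: {etype!r}")
        if not all(isinstance(f, str) for f in feats):
            raise ValueError(f"feature names for {etype!r} must be strings")


check_layer_args(ntypes, canonical_etypes, edge_feat_name, "concat")  # passes silently
```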
- forward(g, h, e_h=None)
Forward computation of the HGT layer with edge feature support.
Parameters
- g: DGLHeteroGraph
Input DGL heterogeneous graph.
- h: dict of Tensor
Node features for each node type in the format of {ntype: tensor}.
- e_h: dict of Tensor
Edge features for each edge type in the format of {etype: tensor}. Default: None.
Returns
dict of Tensor: New node embeddings for each node type in the format of {ntype: tensor}.