HGTLayerwithEdgeFeat

class graphstorm.model.HGTLayerwithEdgeFeat(in_dim, out_dim, ntypes, canonical_etypes, num_heads, edge_feat_name=None, edge_feat_mp_op='concat', activation=None, dropout=0.2, norm='layer', num_ffn_layers_in_gnn=0, fnn_activation=<function relu>)

Bases: HGTLayer

Heterogeneous graph transformer (HGT) layer with edge feature support.

New in version 0.4.1: Added an HGT layer that supports edge features.

This class extends HGTLayer.

This class uses a simple approach to incorporate edge features into the original HGT model: it combines the embeddings of the source node with the embeddings of the edge to form the new K and V, then uses the new K and V in the HGT formulas. The combination options are the same as in the RGCN conv model: concat, add, sub, mul, and div. The formula for computing the message sent on each edge \((s, e, t)\) then becomes:

\[\text{MSG-head}^i(s, e, t) = \text{M-Linear}^i_{\tau(s)}(H^{(l-1)}[s] \text{ op } EF_{e}) W^{MSG}_{\phi(e)}\]

where \(\text{op}\) is one of the add, sub, mul, and div operators, and \(EF_{e}\) is the feature of edge \(e\), whose edge type is \(\phi(e)\).
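The element-wise combination \(H^{(l-1)}[s] \text{ op } EF_{e}\) can be illustrated with a minimal, dependency-free sketch. It assumes the source node embedding and the edge feature already have the same length; the `ELEMENTWISE_OPS` mapping and the `combine` helper are hypothetical names used for illustration only and are not part of GraphStorm.

```python
import operator

# Element-wise operators keyed by the documented option names.
# This mapping is an illustrative assumption, not GraphStorm's internal code.
ELEMENTWISE_OPS = {
    "add": operator.add,
    "sub": operator.sub,
    "mul": operator.mul,
    "div": operator.truediv,
}


def combine(src_emb, edge_feat, op):
    """Combine one source node embedding with one edge feature vector."""
    if len(src_emb) != len(edge_feat):
        raise ValueError("element-wise ops need vectors of equal length")
    fn = ELEMENTWISE_OPS[op]
    return [fn(h, ef) for h, ef in zip(src_emb, edge_feat)]


h_s = [2.0, 4.0, 6.0]   # stands in for H^{(l-1)}[s], the source node embedding
ef_e = [1.0, 2.0, 3.0]  # stands in for EF_e, the feature of edge e

# One combined vector per operator; in the formula above this result is
# what M-Linear is then applied to.
combined = {op: combine(h_s, ef_e, op) for op in ELEMENTWISE_OPS}
```

In this sketch the order of operands follows the formula: the source node embedding is the left operand and the edge feature is the right operand, which matters for sub and div.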

For the concat operator, the formula is

\[\text{MSG-head}^i(s, e, t) = (\text{M-Linear}^i_{\tau(s)}(H^{(l-1)}[s]) + \text{EF-Linear}^i_{\phi(e)}(EF_{e})) W^{MSG}_{\phi(e)}\]

where \(\text{EF-Linear}^i_{\phi(e)}\) is an additional weight for the \(\phi(e)\) edge type.

This formula uses a linear algebra trick to implement the concatenation operation. The linear computation \(concat([e1, e2], dim=-1) @ w\) equals \(e1 @ w1 + e2 @ w2\), where embeddings \(e1\) and \(e2\) have the same dimension \((N, in\_dim)\), weight \(w\) has dimension \((in\_dim * 2, out\_dim)\), and weights \(w1\) and \(w2\) each have dimension \((in\_dim, out\_dim)\) and are the top and bottom halves of \(w\). Based on this trick, instead of first concatenating the source node embeddings and edge embeddings and then applying an edge-type-specific weight of dimension \((in\_dim * 2, out\_dim)\), this implementation applies two separate weights, one for the source node type and one for the edge type, and then adds the transformed embeddings together.
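The identity behind this trick can be checked numerically. The following is a minimal sketch that uses nested Python lists in place of tensors so that it runs without any dependencies; the helper names `matmul` and `rand_matrix` and the sizes chosen are assumptions for illustration, not GraphStorm code.

```python
import random


def matmul(a, b):
    """Multiply two matrices stored as nested lists of floats."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


def rand_matrix(rows, cols, rng):
    """Create a rows x cols matrix with entries drawn from [-1, 1]."""
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
n, in_dim, out_dim = 4, 3, 2

e1 = rand_matrix(n, in_dim, rng)        # e.g. source node embeddings
e2 = rand_matrix(n, in_dim, rng)        # e.g. edge embeddings
w1 = rand_matrix(in_dim, out_dim, rng)  # weight applied to e1
w2 = rand_matrix(in_dim, out_dim, rng)  # weight applied to e2

# w stacks w1 on top of w2, giving shape (in_dim * 2, out_dim).
w = w1 + w2  # list concatenation of rows, not numeric addition

# Concatenate e1 and e2 along the last dimension: shape (n, in_dim * 2).
cat = [row1 + row2 for row1, row2 in zip(e1, e2)]

# Left-hand side: concat([e1, e2], dim=-1) @ w
lhs = matmul(cat, w)

# Right-hand side: e1 @ w1 + e2 @ w2
p1, p2 = matmul(e1, w1), matmul(e2, w2)
rhs = [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(p1, p2)]

# The two results agree up to floating-point rounding.
max_diff = max(abs(a - b) for ra, rb in zip(lhs, rhs) for a, b in zip(ra, rb))
```

Splitting the weight this way lets the node-side and edge-side transformations be computed separately, so the concatenated tensor of shape \((N, in\_dim * 2)\) never has to be materialized.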

For the other HGT formulas, refer to HGTLayer.

Parameters

in_dim: int

Input dimension size.

out_dim: int

Output dimension size.

ntypes: list of str

List of node types in the format of [ntype1, ntype2, …].

canonical_etypes: list of tuple

List of canonical edge types in the format of [(‘src_ntype1’, ‘etype1’, ‘dst_ntype1’), …].

num_heads: int

Number of attention heads.

edge_feat_name: dict of list of str

User provided edge feature names in the format of {etype1:[feat1, feat2, …], etype2:[…], …}, or None if not provided.

edge_feat_mp_op: str

The operation used to combine source node embeddings with edge embeddings in message passing. Options: concat, add, sub, mul, and div. concat concatenates the source node features with the edge features; add adds the edge features to the source node features; sub subtracts the edge features from the source node features; mul multiplies the source node features by the edge features; and div divides the source node features by the edge features. Default: concat.

activation: callable

Activation function. Default: None.

dropout: float

Dropout rate. Default: 0.2.

norm: str

Normalization method. Options: batch, layer, and None. Default: layer.

num_ffn_layers_in_gnn: int

Number of FFN layers between GNN layers. Default: 0.

fnn_activation: torch.nn.functional

Activation function for the FFN layers. Default: relu.

forward(g, h, e_h=None)

Forward computation of the HGT layer with edge feature support.

Parameters

g: DGLHeteroGraph

Input DGL heterogeneous graph.

h: dict of Tensor

Node features for each node type in the format of {ntype: tensor}.

e_h: dict of Tensor

Edge features for each edge type in the format of {etype: tensor}. Default: None.

Returns

dict of Tensor: New node embeddings for each node type in the format of {ntype: tensor}.