HGTLayer

class graphstorm.model.HGTLayer(in_dim, out_dim, ntypes, canonical_etypes, num_heads, activation=None, dropout=0.2, norm='layer', num_ffn_layers_in_gnn=0, fnn_activation=<function relu>)

Bases: Module

Heterogeneous graph transformer (HGT) layer from the paper "Heterogeneous Graph Transformer".

Given a graph \(G(V, E)\) and input node features \(H^{(l-1)}\) at layer \(l-1\), the layer computes the new node features \(H^{(l)}\) at layer \(l\) as follows:

Compute a multi-head attention score for each edge \((s, e, t)\) in the graph:

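\[\begin{split}Attention(s, e, t) = \underset{\forall s\in\mathcal{N}(t)}{\text{Softmax}}\left(||_{i\in[1,h]}\text{ATT-head}^i(s, e, t)\right) \\ \text{ATT-head}^i(s, e, t) = \left(K^i(s)W^{ATT}_{\phi(e)}Q^i(t)^{\top}\right)\cdot \frac{\mu_{\langle\tau(s),\phi(e),\tau(t)\rangle}}{\sqrt{d}} \\ K^i(s) = \text{K-Linear}^i_{\tau(s)}(H^{(l-1)}[s]) \\ Q^i(t) = \text{Q-Linear}^i_{\tau(t)}(H^{(l-1)}[t])\end{split}\]

Here \(h\) is the number of attention heads, \(\tau(\cdot)\) and \(\phi(\cdot)\) map a node and an edge to their types, \(\mu\) is a learnable prior for each \((\tau(s),\phi(e),\tau(t))\) triplet, and \(\sqrt{d}\) is a scaling factor.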
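The score of a single head can be sketched in plain Python. This is a minimal illustration, not GraphStorm's implementation: the weight matrices are hypothetical toy values, the biases of the linear projections are omitted, and, following the HGT paper, each head's scores are normalized with a softmax over the incoming edges of the target node \(t\).

```python
import math

def matvec(W, x):
    """Return W @ x, with W given as a list of rows."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def vecmat(x, W):
    """Return x @ W for a row vector x and a matrix W given as a list of rows."""
    return [sum(x[r] * W[r][c] for r in range(len(x))) for c in range(len(W[0]))]

def att_head(h_s, h_t, k_lin, q_lin, w_att, mu, d):
    """Unnormalized ATT-head^i(s, e, t) for one head and one edge.

    k_lin, q_lin: weights of K-Linear^i_{tau(s)} and Q-Linear^i_{tau(t)} (no bias).
    w_att: W^{ATT}_{phi(e)} for this head; mu: prior of the edge's type triplet.
    """
    k = matvec(k_lin, h_s)                       # K^i(s)
    q = matvec(q_lin, h_t)                       # Q^i(t)
    k_w = vecmat(k, w_att)                       # K^i(s) W^{ATT}_{phi(e)}
    score = sum(a * b for a, b in zip(k_w, q))   # ... times Q^i(t)^T
    return score * mu / math.sqrt(d)

def softmax(scores):
    """Numerically stable softmax over the scores of t's incoming edges."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Toy example: two source nodes s1, s2 send to the same target t (d = 2).
eye = [[1.0, 0.0], [0.0, 1.0]]
h_t = [2.0, 0.0]
scores = [att_head(h_s, h_t, eye, eye, eye, mu=1.0, d=2)
          for h_s in ([1.0, 0.0], [0.0, 1.0])]
weights = softmax(scores)
```

With \(h\) heads, the same computation runs once per head with that head's own weights, and the \(h\) normalized scores are concatenated.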

Compute the message to send on each edge \((s, e, t)\):

\[\begin{split}Message(s, e, t) = ||_{i\in[1, h]} \text{MSG-head}^i(s, e, t) \\ \text{MSG-head}^i(s, e, t) = \text{M-Linear}^i_{\tau(s)}(H^{(l-1)}[s])W^{MSG}_{\phi(e)}\end{split}\]
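The message on one edge can be sketched the same way: each head projects the source feature with its own M-Linear weights, multiplies by the edge-type matrix \(W^{MSG}_{\phi(e)}\), and the per-head results are concatenated. The weights below are hypothetical toy values and biases are omitted; this is an illustration of the formula, not the library's code.

```python
def matvec(W, x):
    """Return W @ x, with W given as a list of rows."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def vecmat(x, W):
    """Return x @ W for a row vector x and a matrix W given as a list of rows."""
    return [sum(x[r] * W[r][c] for r in range(len(x))) for c in range(len(W[0]))]

def msg_head(h_s, m_lin, w_msg):
    """MSG-head^i(s, e, t) = M-Linear^i_{tau(s)}(H[s]) W^{MSG}_{phi(e)}, bias omitted."""
    return vecmat(matvec(m_lin, h_s), w_msg)

def message(h_s, m_lins, w_msgs):
    """Concatenate the per-head messages: ||_{i in [1, h]} MSG-head^i(s, e, t)."""
    out = []
    for m_lin, w_msg in zip(m_lins, w_msgs):
        out.extend(msg_head(h_s, m_lin, w_msg))
    return out

# Toy example: input size 2, h = 2 heads, each head of size 1.
m_lins = [[[1, 0]], [[0, 1]]]   # head 0 keeps feature 0, head 1 keeps feature 1
w_msgs = [[[2]], [[3]]]         # per-head W^{MSG}_{phi(e)}
msg = message([5, 7], m_lins, w_msgs)   # [10, 21]
```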

Send messages to target nodes \(t\) and aggregate:

\[\tilde{H}^{(l)}[t] = \sum_{\forall s\in \mathcal{N}(t)}\left( Attention(s,e,t) \cdot Message(s,e,t)\right)\]
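Because \(Attention(s,e,t)\) holds one weight per head and \(Message(s,e,t)\) is a concatenation of per-head messages, the product in the sum pairs each head's weight with that head's slice of the message. A plain-Python sketch, assuming the attention weights are already normalized and that all heads have the same size:

```python
def aggregate(attn, msgs, num_heads):
    """Compute tilde H^{(l)}[t] as an attention-weighted sum over incoming edges.

    attn[j][i]: softmax weight of head i on the j-th incoming edge of t.
    msgs[j]:    concatenated message on that edge, of length num_heads * d.
    Head i's weight only scales head i's slice of the message.
    """
    d = len(msgs[0]) // num_heads
    out = [0.0] * (num_heads * d)
    for a, m in zip(attn, msgs):
        for i in range(num_heads):
            for k in range(d):
                out[i * d + k] += a[i] * m[i * d + k]
    return out

# Two incoming edges, two heads of size 1; each head's weights sum to 1.
attn = [[0.25, 0.5], [0.75, 0.5]]
msgs = [[4.0, 2.0], [8.0, 6.0]]
h_tilde = aggregate(attn, msgs, num_heads=2)   # [7.0, 4.0]
```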

Compute new node features:

\[H^{(l)}[t]=\text{A-Linear}_{\tau(t)}(\sigma(\tilde{H}^{(l)}[t])) + H^{(l-1)}[t]\]
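The update step can be sketched as below. \(\sigma\) is left as a parameter; ReLU is only a placeholder chosen for this sketch, and the A-Linear weights are hypothetical toy values without a bias. As the formula is written, the residual term requires the transformed vector and \(H^{(l-1)}[t]\) to have the same size; how the layer handles mismatched sizes is not covered here, so the sketch simply raises an error.

```python
def matvec(W, x):
    """Return W @ x, with W given as a list of rows."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def relu(x):
    """Placeholder activation for sigma in this sketch."""
    return max(x, 0.0)

def update(h_tilde, h_prev, a_lin, sigma=relu):
    """H^{(l)}[t] = A-Linear_{tau(t)}(sigma(tilde H^{(l)}[t])) + H^{(l-1)}[t]."""
    transformed = matvec(a_lin, [sigma(x) for x in h_tilde])
    # The residual connection only works when both vectors have the same size.
    if len(transformed) != len(h_prev):
        raise ValueError("residual connection needs matching dimensions")
    return [a + b for a, b in zip(transformed, h_prev)]

h_new = update([7.0, -4.0], [1.0, 1.0], a_lin=[[1.0, 0.0], [0.0, 1.0]])  # [8.0, 1.0]
```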

Note:

  • Unlike DGL’s HGTConv, which operates on a homogeneous graph with per-node and per-edge type IDs, this implementation operates directly on DGL heterogeneous graphs. The default values of the other hyperparameters are the same as those of DGL’s HGTConv.

  • This implementation uses mean as the cross-relation aggregation function, which is the choice the authors of the HGT paper made in their contribution to DGL.
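Mean cross-relation aggregation can be sketched in plain Python with hypothetical node and edge type names: every relation (canonical edge type) that ends at a node's type produces its own aggregated vector for that node, and those vectors are averaged element-wise. This sketch assumes every relation yields a vector for the node; it does not model how nodes without incoming edges of some relation are treated.

```python
def group_by_dst(canonical_etypes):
    """Map each destination node type to the relations that end at it."""
    groups = {}
    for src_ntype, etype, dst_ntype in canonical_etypes:
        groups.setdefault(dst_ntype, []).append((src_ntype, etype, dst_ntype))
    return groups

def cross_relation_mean(per_relation_outputs):
    """Element-wise mean of the per-relation vectors computed for one node."""
    n = len(per_relation_outputs)
    return [sum(vals) / n for vals in zip(*per_relation_outputs)]

canonical_etypes = [("user", "clicks", "item"),
                    ("user", "buys", "item"),
                    ("item", "bought-by", "user")]
groups = group_by_dst(canonical_etypes)
# An "item" node receives one aggregated vector per relation ending at "item".
item_outputs = {("user", "clicks", "item"): [2.0, 4.0],
                ("user", "buys", "item"): [6.0, 0.0]}
combined = cross_relation_mean([item_outputs[r] for r in groups["item"]])  # [4.0, 2.0]
```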

Examples:

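# Suppose a DGL heterogeneous graph g and the input node features input_feature,
# a dict of {ntype: tensor of shape [num_nodes, in_dim]}, are ready.
import torch.nn.functional as F
from graphstorm.model import HGTLayer

in_dim, out_dim, num_heads = 64, 64, 4
layer = HGTLayer(in_dim, out_dim, g.ntypes, g.canonical_etypes,
                 num_heads, activation=F.relu, dropout=0.2, norm='layer')
h = layer(g, input_feature)  # dict of {ntype: tensor}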

Parameters

in_dim: int

Input dimension size.

out_dim: int

Output dimension size.

ntypes: list of str

List of node types in the format of [ntype1, ntype2, …].

canonical_etypes: list of tuple

List of canonical edge types in the format of [('src_ntype1', 'etype1', 'dst_ntype1'), …].

num_heads: int

Number of attention heads.

activation: callable

Activation function. Default: None.

dropout: float

Dropout rate. Default: 0.2.

norm: str

Normalization method. Options: batch, layer, and None. Default: layer.

num_ffn_layers_in_gnn: int

Number of feed-forward network (FFN) layers between GNN layers. Default: 0.

fnn_activation: callable

Activation function for the FFN layers. Default: relu.
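The constructor arguments can be sanity-checked before building the layer. The helpers below are hypothetical, not part of GraphStorm. They assume that the output dimension is split evenly across the num_heads attention heads (so out_dim must be divisible by num_heads) and that every node type named in canonical_etypes must also appear in ntypes.

```python
def head_size(out_dim, num_heads):
    """Size of each attention head, assuming out_dim is split evenly across heads."""
    if out_dim % num_heads != 0:
        raise ValueError(f"out_dim={out_dim} is not divisible by num_heads={num_heads}")
    return out_dim // num_heads

def missing_ntypes(ntypes, canonical_etypes):
    """Node types used by canonical_etypes but absent from ntypes (sorted)."""
    used = {t for src, _, dst in canonical_etypes for t in (src, dst)}
    return sorted(used - set(ntypes))

ntypes = ["item", "user"]
canonical_etypes = [("user", "clicks", "item"), ("item", "bought-by", "user")]
d = head_size(out_dim=64, num_heads=4)            # 16
gaps = missing_ntypes(ntypes, canonical_etypes)   # []
```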

forward(g, h)

HGT layer forward computation.

Parameters

g: DGLHeteroGraph

Input DGL heterogeneous graph.

h: dict of Tensor

Node features for each node type in the format of {ntype: tensor}.

Returns

new_h: dict of Tensor

New node embeddings for each node type in the format of {ntype: tensor}.