RelationalAttLayer

class graphstorm.model.RelationalAttLayer(in_feat, out_feat, rel_names, num_heads, *, bias=True, activation=None, self_loop=False, dropout=0.0, num_ffn_layers_in_gnn=0, fnn_activation=<function relu>, norm=None)

Bases: Module

Relational graph attention layer from Relational Graph Attention Networks.

For the GATConv on each relation type:

\[h_i^{(l+1)} = \sum_{j\in \mathcal{N}(i)} \alpha_{ij}^{(l)} W^{(l)} h_j^{(l)}\]

where \(\alpha_{ij}^{(l)}\) is the attention score between node \(i\) and node \(j\):

\[ \begin{align}\begin{aligned}\alpha_{ij}^{(l)} &= \mathrm{softmax_i} (e_{ij}^{(l)})\\e_{ij}^{(l)} &= \mathrm{LeakyReLU}\left(\vec{a}^T [W^{(l)} h_{i}^{(l)} \| W^{(l)} h_{j}^{(l)}]\right)\end{aligned}\end{align} \]
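The two formulas above can be traced by hand for a single destination node and a single attention head. The following is a minimal pure-Python sketch, not the layer's actual implementation: the helper names (`leaky_relu`, `matvec`, `attention_update`), the toy weights, and the LeakyReLU slope of 0.2 are all assumptions chosen for illustration.

```python
import math


def leaky_relu(x, negative_slope=0.2):
    # negative_slope=0.2 is an assumed value for this sketch.
    return x if x > 0 else negative_slope * x


def matvec(W, h):
    # Plain matrix-vector product W h, with W given as a list of rows.
    return [sum(w * x for w, x in zip(row, h)) for row in W]


def attention_update(h_i, neighbors, W, a):
    """Single-head update for one destination node i (illustrative only).

    h_i: feature vector of node i
    neighbors: list of feature vectors h_j for j in N(i)
    W: weight matrix (out_dim rows, in_dim columns)
    a: attention vector of length 2 * out_dim
    """
    Wh_i = matvec(W, h_i)
    Wh_js = [matvec(W, h_j) for h_j in neighbors]

    # e_ij = LeakyReLU(a^T [W h_i || W h_j]); list "+" is the concatenation.
    e = [
        leaky_relu(sum(x * y for x, y in zip(a, Wh_i + Wh_j)))
        for Wh_j in Wh_js
    ]

    # alpha_ij: softmax of e_ij over the neighbors j of node i.
    e_max = max(e)
    exp_e = [math.exp(v - e_max) for v in e]
    total = sum(exp_e)
    alpha = [v / total for v in exp_e]

    # h_i^{(l+1)} = sum_j alpha_ij W h_j
    h_new = [
        sum(al * Wh_j[k] for al, Wh_j in zip(alpha, Wh_js))
        for k in range(len(Wh_i))
    ]
    return alpha, h_new


# Toy inputs (hypothetical values): identity weights, three neighbors.
W = [[1.0, 0.0], [0.0, 1.0]]
a = [0.5, -0.5, 1.0, 0.0]
h_i = [1.0, 2.0]
neighbors = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]

alpha, h_new = attention_update(h_i, neighbors, W, a)
```

With these toy values the attention weights sum to 1 and the third neighbor, which has the largest raw score, receives the largest weight.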

Note:

  • Within each relation type, messages are aggregated with a multi-head attention network.

  • Across relation types, the per-relation results are combined by a simple average.
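The cross-relation step in the note can be sketched as follows. This is an illustrative pure-Python version and not GraphStorm code: the function `average_across_relations`, the relation names, and the feature values are hypothetical, and the per-relation outputs are assumed to have already been produced by the attention step.

```python
from collections import defaultdict


def average_across_relations(rel_outputs):
    """Average per-relation outputs that share a destination node type.

    rel_outputs: {(src_ntype, etype, dst_ntype): [node_vector, ...]}
    Returns: {dst_ntype: [node_vector, ...]}
    """
    # Group the per-relation results by destination node type.
    grouped = defaultdict(list)
    for (_, _, dst_ntype), feats in rel_outputs.items():
        grouped[dst_ntype].append(feats)

    averaged = {}
    for ntype, per_rel in grouped.items():
        num_rel = len(per_rel)
        # zip(*per_rel) pairs up the same node across relations;
        # zip(*node_vecs) pairs up the same feature dimension.
        averaged[ntype] = [
            [sum(vals) / num_rel for vals in zip(*node_vecs)]
            for node_vecs in zip(*per_rel)
        ]
    return averaged


# Hypothetical per-relation outputs: two relations point into "item",
# one relation points into "user".
rel_outputs = {
    ("user", "rates", "item"): [[1.0, 2.0], [3.0, 4.0]],
    ("user", "views", "item"): [[3.0, 6.0], [5.0, 0.0]],
    ("item", "rated-by", "user"): [[1.0, 1.0]],
}

h = average_across_relations(rel_outputs)
```

Here the two relations ending in "item" are averaged element-wise, while "user" has a single incoming relation and keeps its values unchanged.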

Examples:

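# suppose the graph g and input_feature are ready, and that h_dim, self_loop,
# dropout, num_ffn_layers_in_gnn, fnn_activation, and norm are already defined
from graphstorm.model import RelationalAttLayer

# arguments after num_heads are keyword-only, so they must be passed by name
layer = RelationalAttLayer(
        in_feat=h_dim, out_feat=h_dim, rel_names=g.canonical_etypes,
        num_heads=4, self_loop=self_loop,
        dropout=dropout, num_ffn_layers_in_gnn=num_ffn_layers_in_gnn,
        fnn_activation=fnn_activation, norm=norm)
h = layer(g, input_feature)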

Parameters

in_feat: int

Input feature size.

out_feat: int

Output feature size.

rel_names: list of tuple

Relation type list in the format of [('src_ntype1', 'etype1', 'dst_ntype1'), …].

num_heads: int

Number of attention heads.

bias: bool

Whether to add bias. Default: True.

activation: callable

Activation function. Default: None.

self_loop: bool

Whether to include self loop message. Default: False.

dropout: float

Dropout rate. Default: 0.0.

num_ffn_layers_in_gnn: int

Number of FFN (feed-forward network) layers between GNN layers. Default: 0.

fnn_activation: torch.nn.functional

Activation function for the FFN layers. Default: relu.

norm: str

Normalization method. Options: batch, layer, and None. Default: None, meaning no normalization.

forward(g, inputs)

RGAT layer forward computation.

Parameters

g: DGLHeteroGraph

Input DGL heterogeneous graph.

inputs: dict of Tensor

Node features for each node type in the format of {ntype: tensor}.

Returns

dict of Tensor: New node embeddings for each node type in the format of {ntype: tensor}.