SAGEConv

class graphstorm.model.SAGEConv(in_feat, out_feat, aggregator_type='mean', bias=True, dropout=0.0, activation=<function relu>, num_ffn_layers_in_gnn=0, ffn_activation=<function relu>, norm=None)

Bases: Module

GraphSAGE convolutional layer from Inductive Representation Learning on Large Graphs.

The message passing formulas of SAGEConv are:

\[ \begin{align}\begin{aligned}h_{\mathcal{N}(i)}^{(l+1)} &= \mathrm{aggregate} \left(\{h_{j}^{(l)}, \forall j \in \mathcal{N}(i) \}\right)\\h_{i}^{(l+1)} &= \sigma \left(W \cdot \mathrm{concat} (h_{i}^{(l)}, h_{\mathcal{N}(i)}^{(l+1)}) \right)\\h_{i}^{(l+1)} &= \mathrm{norm}(h_{i}^{(l+1)})\end{aligned}\end{align} \]
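As a rough illustration of the formulas above, the following pure-Python sketch runs one layer with the mean aggregator on a toy graph. It is not GraphStorm's implementation. `sage_layer`, `mean_aggregate`, and the other helpers are hypothetical names, features are plain lists instead of tensors, bias and dropout are left out, and an isolated node is assumed to receive a zero neighbor vector.

```python
def mean_aggregate(h, neighbors, i):
    """h_{N(i)}^{(l+1)}: element-wise mean of the neighbor features h_j^{(l)}."""
    nbrs = neighbors[i]
    dim = len(h[i])
    if not nbrs:  # isolated node: zero vector (an assumption of this sketch)
        return [0.0] * dim
    return [sum(h[j][k] for j in nbrs) / len(nbrs) for k in range(dim)]


def relu(x):
    return [max(0.0, v) for v in x]


def matvec(W, x):
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def sage_layer(h, neighbors, W, activation=relu, norm=None):
    """One step: sigma(W . concat(h_i, h_N(i))), then an optional norm."""
    out = []
    for i in range(len(h)):
        h_neigh = mean_aggregate(h, neighbors, i)
        # List '+' concatenates the self feature and the neighbor summary.
        z = activation(matvec(W, h[i] + h_neigh))
        out.append(norm(z) if norm is not None else z)
    return out


# Toy 3-node graph: neighbors[i] lists the nodes j that send messages to i.
neighbors = {0: [1, 2], 1: [0], 2: []}
h = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
# W maps the 4-dim concatenation back to 2 dims (self part + neighbor part).
W = [[1.0, 0.0, 1.0, 0.0],
     [0.0, 1.0, 0.0, 1.0]]
print(sage_layer(h, neighbors, W))  # → [[2.0, 1.5], [1.0, 1.0], [2.0, 2.0]]
```

Node 0 averages nodes 1 and 2 into [1.0, 1.5] and, with this particular W, adds that summary to its own feature [1.0, 0.0]; node 2 has no neighbors, so its output depends only on itself.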

Note:

  • SAGEConv is only effective on homogeneous graphs.

Examples:

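# Suppose the DGL graph g, the hidden size h_dim, and input_feature are ready.
# input_feature is a dict in the format of {dgl.DEFAULT_NTYPE: tensor}.
import torch.nn.functional as F
from graphstorm.model import SAGEConv

# Pass the options by keyword so that they match the constructor signature.
layer = SAGEConv(h_dim, h_dim,
                 aggregator_type="mean",
                 bias=True,
                 dropout=0.0,
                 activation=F.relu,
                 num_ffn_layers_in_gnn=0,
                 norm=None)
h = layer(g, input_feature)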

Parameters

in_feat: int

Input feature size.

out_feat: int

Output feature size.

aggregator_type: str

Message aggregation type. Options: mean, gcn, pool, lstm. Default: mean.

bias: bool

Whether to add bias. Default: True.

dropout: float

Dropout rate. Default: 0.

activation: callable

Activation function, such as torch.nn.functional.relu. Default: relu.

num_ffn_layers_in_gnn: int

Number of feed-forward network (FFN) layers between GNN layers. Default: 0.

ffn_activation: callable

Activation function for the FFN layers. Default: relu.

norm: str

Normalization method. Options: batch, layer, and None. Default: None, meaning no normalization.
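The aggregator_type options differ in how the neighbor summary h_N(i) is computed. The sketch below assumes that GraphStorm passes this option through to DGL's SAGEConv, where gcn averages the neighbors together with the node itself (and the result is not concatenated with h_i again), and pool takes an element-wise max over ReLU-transformed neighbor features. The function names and the W_pool/b_pool weights are hypothetical, and lstm is omitted because it needs a recurrent cell.

```python
def agg_mean(h, nbrs, i):
    """mean: plain average of the neighbor features."""
    dim = len(h[i])
    if not nbrs:
        return [0.0] * dim
    return [sum(h[j][k] for j in nbrs) / len(nbrs) for k in range(dim)]


def agg_gcn(h, nbrs, i):
    """gcn: average over the neighbors plus node i itself, i.e. |N(i)| + 1 terms."""
    n = len(nbrs) + 1
    return [(h[i][k] + sum(h[j][k] for j in nbrs)) / n for k in range(len(h[i]))]


def agg_pool(h, nbrs, i, W_pool, b_pool):
    """pool: element-wise max over relu(W_pool h_j + b_pool) for each neighbor j."""
    if not nbrs:
        return [0.0] * len(W_pool)
    transformed = [
        [max(0.0, sum(w * v for w, v in zip(row, h[j])) + b)
         for row, b in zip(W_pool, b_pool)]
        for j in nbrs
    ]
    # zip(*...) walks over feature columns so max is taken per dimension.
    return [max(col) for col in zip(*transformed)]


h = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
nbrs_of_0 = [1, 2]
print(agg_mean(h, nbrs_of_0, 0))  # → [1.0, 1.5]
print(agg_gcn(h, nbrs_of_0, 0))   # → [1.0, 1.0]
print(agg_pool(h, nbrs_of_0, 0, [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0]))  # → [2.0, 2.0]
```

On the same neighborhood, the three options give three different summaries: gcn is pulled toward node 0's own feature, while pool keeps the largest transformed value in each dimension.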
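To make the difference between the two norm options concrete, here is a minimal sketch of the standard batch and layer normalization formulas, assuming GraphStorm maps them to the usual PyTorch normalization layers. It omits the learnable scale and shift, as well as the running statistics that torch.nn.BatchNorm1d keeps for evaluation. layer_norm, batch_norm, and _standardize are hypothetical helper names, and each row is one node's output embedding.

```python
import math


def _standardize(values, eps):
    """Shift to zero mean and scale to unit (biased) variance."""
    mu = sum(values) / len(values)
    var = sum((v - mu) ** 2 for v in values) / len(values)
    return [(v - mu) / math.sqrt(var + eps) for v in values]


def layer_norm(rows, eps=1e-5):
    """layer: normalize each node's vector across its own feature dimensions."""
    return [_standardize(r, eps) for r in rows]


def batch_norm(rows, eps=1e-5):
    """batch: normalize each feature column across all nodes in the batch."""
    cols = [_standardize(list(c), eps) for c in zip(*rows)]
    return [list(r) for r in zip(*cols)]


emb = [[1.0, 3.0],
       [3.0, 5.0]]
print(layer_norm(emb))  # roughly [[-1, 1], [-1, 1]]
print(batch_norm(emb))  # roughly [[-1, -1], [1, 1]]
```

Layer normalization makes every node's embedding comparable on its own, while batch normalization depends on which other nodes share the mini-batch.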

forward(g, inputs)

GraphSAGE layer forward computation.

Parameters

g: DGLHeteroGraph

Input DGL graph.

inputs: dict of Tensor

Node features for the default node type in the format of {dgl.DEFAULT_NTYPE: tensor}. The definition of dgl.DEFAULT_NTYPE can be found on the official DGL GitHub site.

Returns

dict of Tensor: New node embeddings for the default node type in the format of {dgl.DEFAULT_NTYPE: tensor}. The definition of dgl.DEFAULT_NTYPE can be found on the official DGL GitHub site.
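Both inputs and the return value use the same single-key dictionary. The sketch below illustrates that contract with hypothetical helpers, wrap_features and unwrap_embeddings, and a stand-in forward function instead of a real layer. It assumes the DGL constant's value is "_N"; real code should import the constant from DGL rather than redefine it.

```python
# Assumption: mirrors the value of DGL's DEFAULT_NTYPE constant ("_N").
DEFAULT_NTYPE = "_N"


def wrap_features(features):
    """Hypothetical helper: build the `inputs` dict that forward() expects."""
    return {DEFAULT_NTYPE: features}


def unwrap_embeddings(outputs):
    """Hypothetical helper: extract the embeddings from forward()'s return value."""
    if set(outputs) != {DEFAULT_NTYPE}:
        raise KeyError(f"expected only {DEFAULT_NTYPE!r}, got {sorted(outputs)}")
    return outputs[DEFAULT_NTYPE]


def fake_forward(g, inputs):
    """Stand-in for layer(g, inputs): doubles every feature, same dict format."""
    return {DEFAULT_NTYPE: [2 * x for x in inputs[DEFAULT_NTYPE]]}


out = fake_forward(None, wrap_features([1.0, 2.0]))
print(unwrap_embeddings(out))  # → [2.0, 4.0]
```

The key check in unwrap_embeddings reflects the note that SAGEConv is only effective on homogeneous graphs: any node type other than the default one indicates that the wrong graph or features were passed in.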