HGTEncoder

class graphstorm.model.HGTEncoder(g, hid_dim, out_dim, num_hidden_layers, num_heads, dropout=0.2, norm='layer', num_ffn_layers_in_gnn=0)

Bases: GraphConvEncoder, GSgnnGNNEncoderInterface

Heterogeneous Graph Transformer (HGT) encoder.

The HGTEncoder stacks multiple HGTLayer modules as its encoding layers. It should be set as the model's GNN encoder within GraphStorm.

Parameters

g: DistGraph

The input distributed graph.

hid_dim: int

Hidden dimension size.

out_dim: int

Output dimension size.

num_hidden_layers: int

Number of hidden layers. The total number of GNN layers is num_hidden_layers + 1.
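To illustrate how num_hidden_layers translates into layers, the sketch below lists the (input, output) dimension of each layer. It assumes that every hidden layer maps hid_dim to hid_dim and that the final layer maps hid_dim to out_dim. The helper hgt_layer_dims is hypothetical and not part of the GraphStorm API.

```python
def hgt_layer_dims(hid_dim, out_dim, num_hidden_layers):
    """Hypothetical helper: (in_dim, out_dim) pairs for each HGT layer.

    Assumption: hidden layers keep hid_dim and the last layer projects
    from hid_dim to out_dim.
    """
    dims = [(hid_dim, hid_dim)] * num_hidden_layers
    dims.append((hid_dim, out_dim))  # the "+ 1" output layer
    return dims


# num_hidden_layers=1 gives 1 + 1 = 2 GNN layers.
print(hgt_layer_dims(hid_dim=4, out_dim=8, num_hidden_layers=1))
```

The list length is always num_hidden_layers + 1, which matches the layer count stated above.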

num_heads: int

Number of attention heads.
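In multi-head attention, each head usually works on a slice of the layer's output dimension. The sketch below assumes, as in the original HGT formulation, that each head uses layer_out_dim // num_heads dimensions, so the layer's output dimension should be divisible by num_heads. This is an assumption about HGTLayer internals; check the GraphStorm source before relying on it. The helper head_dim is hypothetical.

```python
def head_dim(layer_out_dim, num_heads):
    """Hypothetical helper: per-head dimension under the d // h assumption."""
    if num_heads <= 0:
        raise ValueError("num_heads must be positive")
    if layer_out_dim % num_heads != 0:
        raise ValueError(
            f"output dim {layer_out_dim} is not divisible by num_heads {num_heads}")
    return layer_out_dim // num_heads


# hid_dim=4 with num_heads=2 gives 2 dimensions per head.
print(head_dim(4, 2))
```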

dropout: float

Dropout rate. Default: 0.2.

norm: str

Normalization method. Options: batch, layer, and None. Default: layer.
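The difference between the two normalization options can be sketched on a small node-by-feature matrix: layer normalization standardizes each node's embedding across its features, while batch normalization standardizes each feature across the nodes in the batch. This is an illustrative pure-Python sketch, not GraphStorm's implementation; it omits the learnable scale and shift parameters and uses only current-batch statistics (as in training mode).

```python
import math

EPS = 1e-5  # small constant for numerical stability


def _standardize(values):
    """Subtract the mean and divide by the (biased) standard deviation."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + EPS) for v in values]


def apply_norm(h, norm):
    """Illustrative normalization of a list of node rows.

    norm="layer": standardize each row (one node) across its features.
    norm="batch": standardize each column (one feature) across nodes.
    norm=None: return an unchanged copy.
    """
    if norm is None:
        return [list(row) for row in h]
    if norm == "layer":
        return [_standardize(row) for row in h]
    if norm == "batch":
        cols = [_standardize(list(col)) for col in zip(*h)]
        return [list(row) for row in zip(*cols)]
    raise ValueError(f"unsupported norm: {norm!r}")


h = [[1.0, 2.0, 3.0], [4.0, 6.0, 8.0]]
print(apply_norm(h, "layer"))
print(apply_norm(h, "batch"))
```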

num_ffn_layers_in_gnn: int

Number of feed-forward network (FFN) layers between GNN layers. Default: 0.
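A minimal sketch of what extra FFN layers between GNN layers do to one node embedding, assuming each FFN layer is a linear map followed by ReLU. The activation choice and the helper names linear, relu, and ffn are assumptions for illustration, not GraphStorm's code. With num_ffn_layers_in_gnn=0 the list of layers is empty and the embedding passes through unchanged.

```python
def relu(x):
    return [max(0.0, v) for v in x]


def linear(x, weight, bias):
    """Compute weight @ x + bias; weight is a list of output rows."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


def ffn(x, layers):
    """Apply each (weight, bias) pair followed by ReLU (assumed activation).

    An empty `layers` list corresponds to num_ffn_layers_in_gnn=0.
    """
    for weight, bias in layers:
        x = relu(linear(x, weight, bias))
    return x


identity_layer = ([[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0])
print(ffn([1.0, -2.0], [identity_layer]))  # ReLU zeroes the negative entry
print(ffn([1.0, -2.0], []))                # no FFN layers: unchanged
```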

Examples:

# Build an edge model with an HGTEncoder and run full-graph inference
from graphstorm import get_node_feat_size
from graphstorm.model import HGTEncoder
from graphstorm.model import MLPEdgeDecoder
from graphstorm.model import GSgnnEdgeModel, GSNodeEncoderInputLayer
from graphstorm.dataloading import GSgnnData
from graphstorm.model import do_full_graph_inference

np_data = GSgnnData(...)

model = GSgnnEdgeModel(alpha_l2norm=0)
# Encode input node features into 4-dimensional vectors,
# together with learnable node embeddings.
feat_size = get_node_feat_size(np_data.g, "feat")
encoder = GSNodeEncoderInputLayer(np_data.g, feat_size, 4,
                                  dropout=0,
                                  use_node_embeddings=True)
model.set_node_input_encoder(encoder)

# num_hidden_layers=1 gives 2 HGT layers, each with 2 attention heads.
gnn_encoder = HGTEncoder(np_data.g,
                         hid_dim=4,
                         out_dim=4,
                         num_hidden_layers=1,
                         num_heads=2,
                         dropout=0.0,
                         norm="layer",
                         num_ffn_layers_in_gnn=0)
model.set_gnn_encoder(gnn_encoder)
# Decoder predicts 3 classes for edges of type ("n0", "r1", "n1").
model.set_decoder(MLPEdgeDecoder(model.gnn_encoder.out_dims,
                                 3, multilabel=False, target_etype=("n0", "r1", "n1"),
                                 num_ffn_layers=0))

h = do_full_graph_inference(model, np_data)

forward(blocks, h)

HGT encoder forward computation.

Parameters

blocks: list of DGL MFGs

Sampled subgraphs as a list of DGL message flow graphs (MFGs). For more details on MFGs, see the DGL Neighbor Sampling Overview.
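To show how a list of MFGs relates to the GNN layers, the sketch below builds toy blocks by expanding one hop per layer from a set of seed nodes. ToyBlock and build_toy_blocks are hypothetical stand-ins that track only node IDs; real DGL MFGs also carry edges, node types, and features. The key properties illustrated are that there is one block per layer, that the destination nodes of block i are the source nodes of block i + 1, and that each block's destination nodes also appear among its source nodes.

```python
from dataclasses import dataclass


@dataclass
class ToyBlock:
    """Hypothetical stand-in for a DGL MFG: node IDs only, no edges or types."""
    src_nodes: list
    dst_nodes: list


def build_toy_blocks(adj, seeds, num_layers):
    """Return blocks ordered from the input layer to the output layer.

    adj maps a node ID to its in-neighbors. Expansion starts from the seeds
    (the last block's destination nodes) and adds one hop per layer.
    """
    blocks = []
    dst = list(seeds)
    for _ in range(num_layers):
        # Destination nodes come first among the source nodes.
        src = list(dst)
        for node in dst:
            for nbr in adj.get(node, []):
                if nbr not in src:
                    src.append(nbr)
        blocks.insert(0, ToyBlock(src_nodes=src, dst_nodes=dst))
        dst = src
    return blocks


adj = {0: [1, 2], 1: [3], 2: [3], 3: []}
blocks = build_toy_blocks(adj, seeds=[0], num_layers=2)
for b in blocks:
    print(b.src_nodes, "->", b.dst_nodes)
```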

h: dict of Tensor

Input node features for each node type in the format of {ntype: tensor}.

Returns

h: dict of Tensor

New node embeddings for each node type in the format of {ntype: tensor}.
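The forward computation can be pictured as a layer-wise loop: each layer consumes one block and maps the {ntype: tensor} features of the block's source nodes to features of its destination nodes. The sketch below assumes this one-block-per-layer pairing; encoder_forward and toy_layer are hypothetical, plain lists stand in for tensors, and each toy block is a dict giving the number of destination nodes per node type.

```python
def encoder_forward(layers, blocks, h):
    """Sketch of a layer-wise forward pass over a list of MFGs.

    Assumes one block per layer; each layer returns {ntype: features}
    for the destination nodes of its block.
    """
    if len(layers) != len(blocks):
        raise ValueError("expected one block per GNN layer")
    for layer, block in zip(layers, blocks):
        h = layer(block, h)
    return h


def toy_layer(block, h):
    """Hypothetical layer: keep the destination rows (listed first) and double them."""
    return {ntype: [[2.0 * v for v in row] for row in feats[:block[ntype]]]
            for ntype, feats in h.items()}


h = {"user": [[1.0], [2.0], [3.0]], "item": [[4.0], [5.0]]}
blocks = [{"user": 2, "item": 1}, {"user": 1, "item": 1}]
print(encoder_forward([toy_layer, toy_layer], blocks, h))
```

The returned dict keeps the {ntype: features} format of the input, shrunk to the seed nodes of the last block.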