graphstorm.model.HGTEncoder

class graphstorm.model.HGTEncoder(g, hid_dim, out_dim, num_hidden_layers, num_heads, dropout=0.2, norm='layer', num_ffn_layers_in_gnn=0)

Heterogeneous graph transformer (HGT) encoder.

The HGTEncoder stacks several HGTLayers to encode the graph. Set it as the model's GNN encoder in GraphStorm.
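As a rough illustration of how such a stack could be sized, the sketch below lists the (input, output) width of each layer. It assumes `num_hidden_layers` layers of width `hid_dim` followed by one final layer that maps to `out_dim`; that layout is an assumption, so check it against the GraphStorm source. The helper `hgt_layer_dims` is hypothetical and not part of GraphStorm.

```python
def hgt_layer_dims(hid_dim, out_dim, num_hidden_layers):
    """Return assumed (in_dim, out_dim) pairs for a stack of HGT layers.

    Hypothetical helper: assumes hidden layers keep width hid_dim and a
    single final layer projects from hid_dim to out_dim.
    """
    if num_hidden_layers < 0:
        raise ValueError("num_hidden_layers must be non-negative")
    # Hidden layers: hid_dim -> hid_dim
    dims = [(hid_dim, hid_dim) for _ in range(num_hidden_layers)]
    # Final layer: hid_dim -> out_dim
    dims.append((hid_dim, out_dim))
    return dims


layer_dims = hgt_layer_dims(hid_dim=4, out_dim=4, num_hidden_layers=1)
print(layer_dims)  # [(4, 4), (4, 4)]
```

Under this assumption, the total number of layers is `num_hidden_layers + 1`.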

Parameters

g: DGLHeteroGraph

Input graph.

hid_dim: int

Hidden dimension size

out_dim: int

Output dimension size

num_hidden_layers: int

Number of hidden layers

num_heads: int

Number of attention heads

dropout: float

Dropout rate. Default is 0.2

norm: str, optional

Normalization method. Default: 'layer'

num_ffn_layers_in_gnn: int

Number of additional feed-forward (NGNN) layers inserted between GNN layers. Default is 0
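Multi-head attention usually splits each feature vector evenly across heads, so layer widths such as `hid_dim` and `out_dim` generally need to be divisible by `num_heads`. The sketch below checks this before a model is built. The helper `per_head_dim` is hypothetical, and whether HGTEncoder enforces the same constraint is an assumption.

```python
def per_head_dim(dim, num_heads):
    """Return the width of each attention head.

    Hypothetical helper: assumes features are split evenly across heads,
    so dim must be a positive multiple of num_heads.
    """
    if num_heads <= 0:
        raise ValueError("num_heads must be positive")
    if dim % num_heads != 0:
        raise ValueError(
            f"dim={dim} is not divisible by num_heads={num_heads}")
    return dim // num_heads


# Values taken from the example on this page: hid_dim=4, num_heads=2
print(per_head_dim(4, 2))  # 2
```

Running the check on both `hid_dim` and `out_dim` catches mismatched settings early, before any graph data is loaded.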

Examples:

# Build model and do full-graph inference on HGTEncoder
from graphstorm import get_node_feat_size
from graphstorm.model.hgt_encoder import HGTEncoder
from graphstorm.model.edge_decoder import MLPEdgeDecoder
from graphstorm.model import GSgnnEdgeModel, GSNodeEncoderInputLayer
from graphstorm.dataloading import GSgnnNodeTrainData
from graphstorm.model import do_full_graph_inference

np_data = GSgnnNodeTrainData(...)

model = GSgnnEdgeModel(alpha_l2norm=0)
feat_size = get_node_feat_size(np_data.g, 'feat')
encoder = GSNodeEncoderInputLayer(np_data.g, feat_size, 4,
                                  dropout=0,
                                  use_node_embeddings=True)
model.set_node_input_encoder(encoder)

gnn_encoder = HGTEncoder(np_data.g,
                         hid_dim=4,
                         out_dim=4,
                         num_hidden_layers=1,
                         num_heads=2,
                         dropout=0.0,
                         norm='layer',
                         num_ffn_layers_in_gnn=0)
model.set_gnn_encoder(gnn_encoder)
model.set_decoder(MLPEdgeDecoder(model.gnn_encoder.out_dims,
                                 3, multilabel=False, target_etype=("n0", "r1", "n1"),
                                 num_ffn_layers=0))  # 0 is an assumed value: no extra feed-forward layers in the decoder

h = do_full_graph_inference(model, np_data)