graphstorm.model.RelationalGATEncoder

class graphstorm.model.RelationalGATEncoder(g, h_dim, out_dim, num_heads, num_hidden_layers=1, dropout=0, use_self_loop=True, last_layer_act=False, num_ffn_layers_in_gnn=0, norm=None)

Relational graph attention encoder

The RelationalGATEncoder stacks several RelationalAttLayers to encode the graph. It is intended to be set as the GNN encoder of a GraphStorm model.

Parameters

g: DGLHeteroGraph

Input graph.

h_dim: int

Hidden dimension size

out_dim: int

Output dimension size

num_heads: int

Number of attention heads

num_hidden_layers: int

Number of hidden layers. Default: 1

dropout: float

Dropout rate. Default: 0

use_self_loop: bool

Whether to use a self-loop. Default: True

last_layer_act: bool

Whether to add an activation at the last layer. Default: False

num_ffn_layers_in_gnn: int

Number of NGNN feed-forward layers between GNN layers. Default: 0

norm: str, optional

Normalization method. Default: None
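As a rough sketch of how h_dim, out_dim, num_hidden_layers and num_heads might interact, the snippet below assumes (this page does not confirm it) that the encoder builds num_hidden_layers layers of width h_dim followed by one output layer of width out_dim, and that each layer splits its output width evenly across the attention heads. The helpers plan_layer_dims and per_head_dim are hypothetical and are not part of GraphStorm.

```python
from typing import List, Tuple


def plan_layer_dims(h_dim: int, out_dim: int,
                    num_hidden_layers: int = 1) -> List[Tuple[int, int]]:
    """Hypothetical helper: (input, output) width of each layer in the stack.

    Assumption: num_hidden_layers layers of h_dim -> h_dim, then one
    output layer of h_dim -> out_dim.
    """
    if num_hidden_layers < 0:
        raise ValueError("num_hidden_layers must be non-negative")
    dims = [(h_dim, h_dim) for _ in range(num_hidden_layers)]
    dims.append((h_dim, out_dim))
    return dims


def per_head_dim(dim: int, num_heads: int) -> int:
    """Hypothetical helper: width handled by a single attention head.

    Assumption: head outputs are concatenated, so the layer width must be
    divisible by num_heads.
    """
    if num_heads <= 0 or dim % num_heads != 0:
        raise ValueError("dim must be a positive multiple of num_heads")
    return dim // num_heads


# Values matching the example on this page: h_dim=4, out_dim=4, num_heads=2.
layers = plan_layer_dims(4, 4, num_hidden_layers=1)
head_widths = [per_head_dim(out, 2) for _, out in layers]
print(layers, head_widths)
```

Under these assumptions, choosing h_dim and out_dim as multiples of num_heads avoids a width that cannot be split evenly across heads.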

Examples:

# Build model and do full-graph inference on RelationalGATEncoder
from graphstorm import get_node_feat_size
from graphstorm.model.rgat_encoder import RelationalGATEncoder
from graphstorm.model.node_decoder import EntityClassifier
from graphstorm.model import GSgnnNodeModel, GSNodeEncoderInputLayer
from graphstorm.dataloading import GSgnnNodeTrainData
from graphstorm.model import do_full_graph_inference

np_data = GSgnnNodeTrainData(...)
g = np_data.g  # graph used by the encoders below

model = GSgnnNodeModel(alpha_l2norm=0)
feat_size = get_node_feat_size(np_data.g, 'feat')
encoder = GSNodeEncoderInputLayer(g, feat_size, 4,
                                  dropout=0,
                                  use_node_embeddings=True)
model.set_node_input_encoder(encoder)

gnn_encoder = RelationalGATEncoder(g, 4, 4,
                                   num_heads=2,
                                   num_hidden_layers=1,
                                   dropout=0,
                                   use_self_loop=True,
                                   norm=None)  # default: no normalization
model.set_gnn_encoder(gnn_encoder)
model.set_decoder(EntityClassifier(model.gnn_encoder.out_dims, 3, False))

h = do_full_graph_inference(model, np_data)