GATConvwithEdgeFeat

class graphstorm.model.GATConvwithEdgeFeat(in_feats, out_feats, num_heads, edge_feat_mp_op='concat', feat_drop=0.0, attn_drop=0.0, negative_slope=0.2, self_loop=True, activation=None, bias=True)

Bases: Module

Graph attention (GAT) layer that supports edge features in message passing computation.

New in version 0.4.1: Added the GATConvwithEdgeFeat class to support edge features in message passing computation.

Parameters

in_feats: int

Input feature size.

out_feats: int

Output feature size.

num_heads: int

Number of heads in Multi-Head Attention.

edge_feat_mp_op: str, optional

The operation used to combine source node embeddings with edge embeddings in message passing. Options are concat, add, sub, mul, and div. The concat operation concatenates the source node features with the edge features; add adds the edge features to the source node features; sub subtracts the edge features from the source node features; mul multiplies the source node features by the edge features; and div divides the source node features by the edge features. Default: concat.

feat_drop: float, optional

Dropout rate on features. Default: 0.

attn_drop: float, optional

Dropout rate on attention weights. Default: 0.

negative_slope: float, optional

Negative slope of the LeakyReLU activation. Default: 0.2.

residual: bool, optional

If True, use a residual connection. Default: False.

activation: callable activation function, optional

If not None, applies an activation function to the updated node features. Default: None.

bias: bool, optional

Whether to add a bias term. Default: True.
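As a rough illustration of the five edge_feat_mp_op options, the sketch below applies each operation to a batch of per-edge source node features and edge features in plain PyTorch. The helper combine_src_edge is hypothetical and is not part of the GraphStorm API. It assumes that, for add, sub, mul, and div, both feature tensors already have the same shape, whereas concat only requires them to agree on the number of edges.

```python
import torch


def combine_src_edge(src_feat: torch.Tensor, edge_feat: torch.Tensor,
                     op: str = "concat") -> torch.Tensor:
    """Combine per-edge source node features with edge features.

    Hypothetical helper illustrating the edge_feat_mp_op options;
    it is not part of the GraphStorm API.
    """
    if op == "concat":
        # Concatenate along the feature dimension: the output width is
        # the sum of the two input widths.
        return torch.cat([src_feat, edge_feat], dim=-1)
    if op == "add":
        return src_feat + edge_feat
    if op == "sub":
        # Subtract the edge features from the source node features.
        return src_feat - edge_feat
    if op == "mul":
        return src_feat * edge_feat
    if op == "div":
        # Divide the source node features by the edge features.
        return src_feat / edge_feat
    raise ValueError(f"Unsupported edge_feat_mp_op: {op}")


# Three edges, each with a 4-dim source node feature and a 4-dim edge feature.
src = torch.ones(3, 4)
edge = torch.full((3, 4), 2.0)
print(combine_src_edge(src, edge, "concat").shape)  # torch.Size([3, 8])
print(combine_src_edge(src, edge, "div")[0, 0].item())  # 0.5
```

In this sketch, concat widens each message to the sum of the two feature sizes, while the other four options keep the width unchanged and therefore need matching feature sizes.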

forward(rel_graph, inputs, get_attention=False, weight=None, edge_weight=None)

Forward computation of the GAT convolution with edge features.

Parameters

rel_graph: DGLGraph

Input DGL heterogeneous graph that contains only one edge type.

inputs: tuple of Tensors

Tuple of input node and edge features, for example (src_inputs, dst_inputs, edge_inputs).

get_attention: bool, optional

Whether to return the attention values. Default: False.

weight: dict of Tensor, optional

External node weight tensors. Not implemented; reserved for future use.

edge_weight: Tensor, optional

External edge weight tensor. Not implemented; reserved for future use.

Returns

h: Tensor

New node embeddings for the destination node type.
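The forward computation itself is not spelled out above. The sketch below is one plausible single-head version in plain PyTorch, loosely following the standard GAT recipe. Each edge message combines the projected source node features with the projected edge features (here using the add option). A LeakyReLU score per edge is normalized with a softmax over each destination node's incoming edges, and the weighted messages are summed into the destination embeddings. Several things here are assumptions for illustration only: the function edge_gat_forward, the weight matrices w_src, w_dst, w_edge, and attn, the explicit src_idx/dst_idx edge lists (standing in for rel_graph), and the tensor layout of the inputs tuple. The sketch also omits dropout, multiple heads, bias, and activation, and GraphStorm's actual implementation may differ.

```python
import torch
import torch.nn.functional as F


def edge_gat_forward(src_idx, dst_idx, src_inputs, dst_inputs, edge_inputs,
                     w_src, w_dst, w_edge, attn, negative_slope=0.2):
    """Single-head, edge-aware GAT aggregation over one edge type.

    Hypothetical sketch, not GraphStorm's internal computation.
    src_idx, dst_idx: (E,) long tensors with each edge's endpoints.
    """
    num_dst = dst_inputs.shape[0]
    # Per-edge message: projected source feature plus projected edge feature.
    msg = src_inputs[src_idx] @ w_src + edge_inputs @ w_edge         # (E, d)
    z_dst = dst_inputs @ w_dst                                         # (N_dst, d)
    # Raw attention score per edge: LeakyReLU(attn^T [z_dst || msg]).
    pair = torch.cat([z_dst[dst_idx], msg], dim=-1)                    # (E, 2d)
    e = F.leaky_relu(pair @ attn, negative_slope=negative_slope)       # (E,)
    # Softmax over the incoming edges of each destination node,
    # shifted by the per-destination max for numerical stability.
    e_max = torch.full((num_dst,), float("-inf")).scatter_reduce(
        0, dst_idx, e, reduce="amax")
    exp_e = torch.exp(e - e_max[dst_idx])
    denom = torch.zeros(num_dst).index_add_(0, dst_idx, exp_e)
    alpha = exp_e / denom[dst_idx]                                     # (E,)
    # Attention-weighted sum of messages into each destination node.
    return torch.zeros(num_dst, msg.shape[1]).index_add_(
        0, dst_idx, alpha.unsqueeze(-1) * msg)


torch.manual_seed(0)
in_dim, edge_dim, out_dim = 5, 3, 4
# Toy graph with a single edge type: 3 source nodes, 2 destination nodes, 4 edges.
src_idx = torch.tensor([0, 1, 2, 2])
dst_idx = torch.tensor([0, 0, 1, 0])
inputs = (torch.randn(3, in_dim), torch.randn(2, in_dim), torch.randn(4, edge_dim))
src_inputs, dst_inputs, edge_inputs = inputs
w_src = torch.randn(in_dim, out_dim)
w_dst = torch.randn(in_dim, out_dim)
w_edge = torch.randn(edge_dim, out_dim)
attn = torch.randn(2 * out_dim)
h = edge_gat_forward(src_idx, dst_idx, src_inputs, dst_inputs, edge_inputs,
                     w_src, w_dst, w_edge, attn)
print(h.shape)  # torch.Size([2, 4])
```

In the standard GAT formulation, num_heads > 1 repeats this computation independently per head with separate parameters, and the per-head results are then concatenated or averaged.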