LinkPredictContrastiveDistMultDecoder

class graphstorm.model.LinkPredictContrastiveDistMultDecoder(etypes, h_dim, gamma=12.0)

Bases: LinkPredictDistMultDecoder

Decoder for link prediction designed for contrastive loss, using DistMult as the score function.
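DistMult scores an edge (u, r, v) as the trilinear product of the source node embedding, a per-relation vector, and the destination node embedding: score = Σ_i h_u[i] · w_r[i] · h_v[i]. The pure-Python sketch below illustrates only that formula. The helper name `distmult_score` is hypothetical and is not part of the GraphStorm API, which works on batched tensors.

```python
from typing import Sequence


def distmult_score(src_emb: Sequence[float],
                   rel_emb: Sequence[float],
                   dst_emb: Sequence[float]) -> float:
    """Hypothetical helper: DistMult score sum_i src[i] * rel[i] * dst[i].

    Equivalent to src^T diag(rel) dst, a bilinear form whose relation
    matrix is restricted to be diagonal.
    """
    if not (len(src_emb) == len(rel_emb) == len(dst_emb)):
        raise ValueError("all three vectors must have the same dimension h_dim")
    return sum(s * r * d for s, r, d in zip(src_emb, rel_emb, dst_emb))


# 1*0.5*2 + 2*1*1 + 3*2*1 = 1 + 2 + 6
print(distmult_score([1.0, 2.0, 3.0], [0.5, 1.0, 2.0], [2.0, 1.0, 1.0]))  # 9.0
```

Because the product is elementwise, DistMult is symmetric in source and destination: swapping the two embeddings yields the same score, so on its own it cannot distinguish an edge from its reverse.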

Note:

This class is specifically implemented for contrastive loss, but it can also be used with other pair-wise loss functions for link prediction tasks.
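To show what "contrastive" means at the score level, here is a minimal sketch of a generic InfoNCE-style loss for one positive edge and its sampled negative edges, computed from decoder scores. This is an assumption for illustration: the function name `info_nce_loss` and its `temperature` parameter are hypothetical, and GraphStorm's own contrastive loss may group negatives or scale scores differently.

```python
import math
from typing import Sequence


def info_nce_loss(pos_score: float, neg_scores: Sequence[float],
                  temperature: float = 1.0) -> float:
    """Hypothetical InfoNCE-style loss for a single positive edge.

    loss = -log( exp(s_pos/t) / (exp(s_pos/t) + sum_j exp(s_neg_j/t)) )
    evaluated with the log-sum-exp trick so large scores do not overflow.
    """
    logits = [pos_score / temperature] + [s / temperature for s in neg_scores]
    m = max(logits)
    log_denom = m + math.log(sum(math.exp(x - m) for x in logits))
    # Subtracting the positive logit gives -log(softmax probability of the positive).
    return log_denom - logits[0]


# Equal scores over 4 candidates: the positive gets probability 1/4, loss = log(4).
print(round(info_nce_loss(0.0, [0.0, 0.0, 0.0]), 4))  # 1.3863
```

The loss consumes only per-edge scores, which is why a score-producing decoder like this one can also feed other pair-wise losses that compare positive and negative edges.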

Parameters

etypes: list of tuples

The canonical edge types of the edges used during model training, in the format of [(src_ntype1, etype1, dst_ntype1), …].

h_dim: int

The input dimension, i.e., the dimension of both the source and destination node embeddings.

gamma: float

The gamma value for model weight initialization. Default: 12.0.
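The three constructor parameters fit together as follows in standard DistMult: each canonical edge type in `etypes` gets one learnable relation vector of length `h_dim`, and `gamma` controls the initial range of those values. The sketch below is hypothetical (the name `init_relation_table` is not a GraphStorm API), and the specific rule of drawing values uniformly from [-gamma / h_dim, gamma / h_dim] is an assumption borrowed from DGL-KE-style initialization; check the GraphStorm source for the actual scheme.

```python
import random
from typing import Dict, List, Tuple

CanonicalEtype = Tuple[str, str, str]


def init_relation_table(etypes: List[CanonicalEtype], h_dim: int,
                        gamma: float = 12.0, seed: int = 0
                        ) -> Dict[CanonicalEtype, List[float]]:
    """Hypothetical sketch: one length-h_dim relation vector per canonical etype.

    ASSUMPTION: values are drawn uniformly from [-gamma / h_dim, gamma / h_dim].
    """
    if h_dim <= 0:
        raise ValueError("h_dim must be positive")
    for etype in etypes:
        if len(etype) != 3:
            raise ValueError(f"expected (src_ntype, etype, dst_ntype), got {etype!r}")
    bound = gamma / h_dim
    rng = random.Random(seed)  # seeded so the example is reproducible
    return {etype: [rng.uniform(-bound, bound) for _ in range(h_dim)]
            for etype in etypes}


etypes = [("user", "rates", "movie"), ("movie", "rated-by", "user")]
table = init_relation_table(etypes, h_dim=4, gamma=12.0)
print(len(table), len(table[("user", "rates", "movie")]))  # 2 4
```

Under this assumption, a larger `gamma` (or a smaller `h_dim`) widens the initial range of the relation weights.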

forward(g, h, e_h=None)

Link prediction decoder forward function using DistMult as the score function.

It computes a score for every edge of every edge type in the input graph.

Parameters

g: DGLGraph

The input graph.

h: dict of Tensor

The input node embeddings in the format of {ntype: emb}.

e_h: dict of Tensor

The input edge embeddings in the format of {(src_ntype, etype, dst_ntype): emb}. Not used, but reserved for future support of edge embeddings. Default: None.

Returns

scores: dict of Tensor

The scores for edges of all edge types in the input graph in the format of {(src_ntype, etype, dst_ntype): score}.
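The input and output formats of `forward` can be illustrated with the plain-Python sketch below. It is hypothetical (the name `distmult_forward` is not a GraphStorm API): a dict of per-type source/destination ID lists stands in for the DGLGraph `g`, whereas with DGL those IDs would typically come from `g.edges(etype=...)` for each entry of `g.canonical_etypes`. Like the documented `forward`, it ignores edge embeddings.

```python
from typing import Dict, List, Sequence, Tuple

CanonicalEtype = Tuple[str, str, str]
Vector = Sequence[float]


def distmult_forward(edges: Dict[CanonicalEtype, Tuple[List[int], List[int]]],
                     h: Dict[str, List[Vector]],
                     rel: Dict[CanonicalEtype, Vector]
                     ) -> Dict[CanonicalEtype, List[float]]:
    """Hypothetical stand-in for forward(g, h): score every edge of every etype.

    edges: {(src_ntype, etype, dst_ntype): (src_ids, dst_ids)}, replacing g
    h:     {ntype: list of node embeddings}, as in the documented h argument
    rel:   {(src_ntype, etype, dst_ntype): relation vector}
    """
    scores = {}
    for canonical_etype, (src_ids, dst_ids) in edges.items():
        src_ntype, _, dst_ntype = canonical_etype
        w = rel[canonical_etype]
        etype_scores = []
        for u, v in zip(src_ids, dst_ids):
            hu, hv = h[src_ntype][u], h[dst_ntype][v]
            # DistMult: elementwise product of source, relation, destination.
            etype_scores.append(sum(a * r * b for a, r, b in zip(hu, w, hv)))
        scores[canonical_etype] = etype_scores
    return scores


h = {"user": [[1.0, 0.0], [0.0, 1.0]], "movie": [[2.0, 3.0]]}
rel = {("user", "rates", "movie"): [1.0, 2.0]}
edges = {("user", "rates", "movie"): ([0, 1], [0, 0])}
print(distmult_forward(edges, h, rel))  # {('user', 'rates', 'movie'): [2.0, 6.0]}
```

The returned dict is keyed by canonical edge type, matching the documented `scores` format; a real implementation would compute each entry as one batched tensor operation rather than a Python loop.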