LinkPredictDistMultDecoder

class graphstorm.model.LinkPredictDistMultDecoder(etypes, h_dim, gamma=12.0)

Bases: LinkPredictMultiRelationLearnableDecoder

Decoder for link prediction that uses DistMult as the score function.
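
DistMult scores an edge (s, r, d) as the sum, over embedding dimensions, of the elementwise product of the source node embedding, a learnable relation vector for the edge type, and the destination node embedding. The following is a minimal PyTorch sketch of that formula for a batch of edges of one edge type. The function name distmult_score and the tensor shapes are illustrative assumptions, not GraphStorm's internal code.

```python
import torch as th


def distmult_score(src_emb: th.Tensor, rel_emb: th.Tensor, dst_emb: th.Tensor) -> th.Tensor:
    """Illustrative DistMult score: sum_k s_k * r_k * d_k for each edge.

    src_emb, dst_emb: (num_edges, h_dim) embeddings of the edge endpoints.
    rel_emb: (h_dim,) relation vector shared by all edges of one edge type;
             broadcasting applies it to every row.
    Returns a (num_edges,) tensor of scores.
    """
    return th.sum(src_emb * rel_emb * dst_emb, dim=-1)


src = th.tensor([[1.0, 2.0], [0.5, -1.0]])
dst = th.tensor([[3.0, 1.0], [2.0, 4.0]])
rel = th.tensor([2.0, 0.5])
print(distmult_score(src, rel, dst))  # tensor([7., 0.])
```

The product is symmetric in s and d, so DistMult gives (s, r, d) and (d, r, s) the same score. On its own, it cannot distinguish an edge's direction.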

Parameters

etypes: list of tuples

The canonical edge types of the graph in the format of [(src_ntype1, etype1, dst_ntype1), …]

h_dim: int

The input dimension size, shared by both the source and destination node embeddings.

gamma: float

The gamma value for model weight initialization. Default: 12.0.

init_w_relation()

Initialize learnable relation embeddings.

An example:

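from torch import nn

def init_w_relation(self):
    # One learnable relation vector of size h_dim per edge type.
    self._w_relation = nn.Embedding(self.num_rels, self.h_dim)
    # Initialize all relation vectors uniformly in [-1, 1].
    nn.init.uniform_(self._w_relation.weight, -1., 1.)
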
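
The example above fits a broader pattern in which the decoder keeps one row of the embedding table per canonical edge type and looks up that row by the edge type's position in etypes. The self-contained sketch below assumes that mapping. The class ToyRelationTable and its etype2id and rel_emb members are hypothetical names for illustration, not part of GraphStorm.

```python
from torch import nn


class ToyRelationTable(nn.Module):
    """Hypothetical stand-in for the relation-embedding part of a DistMult decoder."""

    def __init__(self, etypes, h_dim):
        super().__init__()
        self.etypes = list(etypes)
        self.num_rels = len(self.etypes)
        self.h_dim = h_dim
        # Assumed mapping: each canonical edge type -> its row in the table.
        self.etype2id = {etype: i for i, etype in enumerate(self.etypes)}
        self.init_w_relation()

    def init_w_relation(self):
        # Same initialization as the example above: uniform in [-1, 1].
        self._w_relation = nn.Embedding(self.num_rels, self.h_dim)
        nn.init.uniform_(self._w_relation.weight, -1.0, 1.0)

    def rel_emb(self, etype):
        # Return the (h_dim,) relation vector for one canonical edge type.
        return self._w_relation.weight[self.etype2id[etype]]


etypes = [("user", "rates", "movie"), ("movie", "rated-by", "user")]
table = ToyRelationTable(etypes, h_dim=4)
print(table.rel_emb(("user", "rates", "movie")).shape)  # torch.Size([4])
```
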
forward(g, h, e_h=None)

Link prediction decoder forward function that uses DistMult as the score function.

It computes a score for every edge of every edge type in the input graph.

Parameters

g: DGLGraph

The input graph.

h: dict of Tensor

The input node embeddings in the format of {ntype: emb}.

e_h: dict of Tensor

The input edge embeddings in the format of {(src_ntype, etype, dst_ntype): emb}. Not used, but reserved for future support of edge embeddings. Default: None.

Returns

scores: dict of Tensor

The scores for edges of all edge types in the input graph in the format of {(src_ntype, etype, dst_ntype): score}.
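
Below is a minimal sketch of what this forward pass computes, assuming the DistMult formula and one relation vector per edge type. To keep the sketch self-contained, a plain dict of (src_ids, dst_ids) index tensors per canonical edge type stands in for the DGLGraph. In DGL, g.edges(etype=canonical_etype) returns such a pair. The function toy_forward and its arguments are hypothetical.

```python
import torch as th


def toy_forward(edges, h, rel_embs):
    """Hypothetical sketch of a DistMult forward pass over every edge type.

    edges: {canonical_etype: (src_ids, dst_ids)}, a plain-dict stand-in for
           what DGLGraph.edges(etype=canonical_etype) returns.
    h: {ntype: (num_nodes, h_dim) node embeddings}.
    rel_embs: {canonical_etype: (h_dim,) relation vector}.
    Returns {canonical_etype: (num_edges,) scores}.
    """
    scores = {}
    for can_etype, (src_ids, dst_ids) in edges.items():
        src_type, _, dst_type = can_etype
        src_emb = h[src_type][src_ids]  # (num_edges, h_dim)
        dst_emb = h[dst_type][dst_ids]  # (num_edges, h_dim)
        scores[can_etype] = th.sum(src_emb * rel_embs[can_etype] * dst_emb, dim=-1)
    return scores


h = {"user": th.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]), "movie": th.full((3, 3), 2.0)}
can_etype = ("user", "rates", "movie")
edges = {can_etype: (th.tensor([0, 1, 1]), th.tensor([2, 0, 1]))}
rel_embs = {can_etype: th.tensor([1.0, 1.0, 0.5])}
out = toy_forward(edges, h, rel_embs)
print(out[can_etype])
```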

calc_test_scores(emb, pos_neg_tuple, neg_sample_type, device)

Compute scores for positive edges and negative edges.

Parameters

emb: dict of Tensor

Node embeddings in the format of {ntype: emb}.

pos_neg_tuple: dict of tuple

Positive and negative edges stored in a dict of tuples in the format of {("src_ntype1", "etype1", "dst_ntype1"): (pos_src_idx, neg_src_idx, pos_dst_idx, neg_dst_idx)}.

Each index is a torch.Tensor: pos_src_idx holds the positive source node indexes, neg_src_idx the negative source node indexes, pos_dst_idx the positive destination node indexes, and neg_dst_idx the negative destination node indexes.

We define positive and negative edges as:

  • The positive edges: (pos_src_idx, pos_dst_idx)

  • The negative edges: (pos_src_idx, neg_dst_idx) and (neg_src_idx, pos_dst_idx)
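
The following pure-Python sketch lists the positive and negative pairs for one edge type from the four index sequences. It assumes that every positive source is paired with every negative destination, and every positive destination with every negative source, which matches the Joint option described below. The function enumerate_edges and the toy indexes are hypothetical.

```python
def enumerate_edges(pos_src, neg_src, pos_dst, neg_dst):
    """Hypothetical helper: list positive and negative (src, dst) pairs."""
    # Positive edges pair the i-th positive source with the i-th positive destination.
    positive = list(zip(pos_src, pos_dst))
    # Corrupt destinations: keep each positive source, swap in negative destinations.
    neg_by_dst = [(s, d) for s in pos_src for d in neg_dst]
    # Corrupt sources: keep each positive destination, swap in negative sources.
    neg_by_src = [(s, d) for d in pos_dst for s in neg_src]
    return positive, neg_by_dst + neg_by_src


positive, negative = enumerate_edges(pos_src=[0, 1], neg_src=[7], pos_dst=[5, 6], neg_dst=[3, 4])
print(positive)  # [(0, 5), (1, 6)]
print(negative)  # [(0, 3), (0, 4), (1, 3), (1, 4), (7, 5), (7, 6)]
```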

neg_sample_type: str

Describes how negative edges are sampled. There are two options:

  • Uniform: For each positive edge, we sample K negative edges.

  • Joint: For each batch of positive edges, we sample K negative edges that are shared by all positive edges in the batch.
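
The following PyTorch sketch shows how the two options can lead to different score computations, for destination-corrupted negatives only. It assumes, for illustration, that uniform negatives arrive as a (num_pos, K) index tensor and joint negatives as a single (K,) index tensor, and that the option is passed as the strings "uniform" and "joint". These layouts and strings, and the function neg_dst_scores, are hypothetical, not GraphStorm's internal format.

```python
import torch as th


def neg_dst_scores(emb_src, emb_dst, rel, pos_src, neg_dst, neg_sample_type):
    """Hypothetical sketch: DistMult scores for destination-corrupted negatives.

    Assumed layouts (illustrative only):
      "uniform": neg_dst has shape (num_pos, K), K negatives per positive edge.
      "joint":   neg_dst has shape (K,), the same K negatives for the whole batch.
    Returns a (num_pos, K) score tensor in both cases.
    """
    src = emb_src[pos_src] * rel  # (num_pos, h_dim)
    if neg_sample_type == "uniform":
        dst = emb_dst[neg_dst]  # (num_pos, K, h_dim)
        return th.sum(src.unsqueeze(1) * dst, dim=-1)
    if neg_sample_type == "joint":
        dst = emb_dst[neg_dst]  # (K, h_dim)
        # One matrix product scores every positive source against every shared negative.
        return src @ dst.T
    raise ValueError(f"unknown neg_sample_type: {neg_sample_type}")


th.manual_seed(0)
emb_src, emb_dst, rel = th.randn(4, 8), th.randn(10, 8), th.randn(8)
pos_src = th.tensor([0, 1, 2])
uniform = neg_dst_scores(emb_src, emb_dst, rel, pos_src, th.randint(0, 10, (3, 5)), "uniform")
joint = neg_dst_scores(emb_src, emb_dst, rel, pos_src, th.randint(0, 10, (5,)), "joint")
print(uniform.shape, joint.shape)  # torch.Size([3, 5]) torch.Size([3, 5])
```

A common reason to prefer joint sampling is cost. The shared negatives are embedded once per batch, and a single matrix product scores them against all positives.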

device: th.device

Device used to compute scores.

Returns

scores: dict of tuple

A dict mapping each edge type to its positive and negative scores, in the format of {(src_ntype, etype, dst_ntype): (pos_scores, neg_scores)}.
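
One common way to use this output is ranking-based evaluation, such as mean reciprocal rank (MRR). The sketch below assumes that, for each edge type, pos_scores has shape (num_pos,) and neg_scores has shape (num_pos, K). The function mrr_from_scores and its tie-handling rule are illustrative choices, not GraphStorm's evaluator.

```python
import torch as th


def mrr_from_scores(scores):
    """Hypothetical sketch: mean reciprocal rank per edge type.

    scores: {canonical_etype: (pos_scores, neg_scores)}, with pos_scores of
            shape (num_pos,) and neg_scores of shape (num_pos, K) (assumed layout).
    A positive edge's rank is 1 plus the number of its negatives scoring at
    least as high (ties are counted against the positive).
    """
    mrr = {}
    for etype, (pos, neg) in scores.items():
        ranks = 1 + (neg >= pos.unsqueeze(1)).sum(dim=1)
        mrr[etype] = (1.0 / ranks.float()).mean().item()
    return mrr


et = ("user", "rates", "movie")
pos = th.tensor([3.0, 0.5])
neg = th.tensor([[1.0, 2.0], [1.0, 0.0]])
result = mrr_from_scores({et: (pos, neg)})
print(result[et])  # 0.75
```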

property in_dims

Return the input dimension size, i.e., the h_dim given at class initialization.

property out_dims

Return 1 for link prediction tasks.