WeightedLinkPredictAdvBCELossFunc

class graphstorm.model.WeightedLinkPredictAdvBCELossFunc(adversarial_temperature)

Bases: LinkPredictAdvBCELossFunc

Binary cross-entropy loss for link prediction tasks that applies adversarial weighting to negative samples and per-edge weights to positive samples.

The loss of a positive edge is defined as:

\[loss_{pos} = -w \cdot \log(score)\]

where score is the score of a positive edge computed by the score function, and w is the weight of that edge. The loss of negative edges is the same as in LinkPredictAdvBCELossFunc, where adversarial_temperature sets the temperature of the adversarial weighting applied to negative samples.
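A minimal, framework-free sketch of the weighted positive term follows. It assumes the scores are already probabilities in (0, 1), as the \(-\log(score)\) form implies; if a score function emits logits, a sigmoid would be applied first. The helper `weighted_pos_loss` is hypothetical and not part of the GraphStorm API, and plain lists stand in for Tensors.

```python
import math


def weighted_pos_loss(scores, weights, eps=1e-12):
    """Hypothetical helper: per-edge loss_pos = -w * log(score)."""
    if len(scores) != len(weights):
        raise ValueError("each positive edge needs exactly one weight")
    # Clamp the score so a value of exactly 0 does not hit log(0).
    return [-w * math.log(max(s, eps)) for s, w in zip(scores, weights)]


# Same score, double the weight: the third edge contributes twice the loss of the second.
losses = weighted_pos_loss([0.9, 0.5, 0.5], [1.0, 1.0, 2.0])
```

Because the weight multiplies the log term, a weight of 0 removes an edge from the loss entirely, and weights above 1 make the model pay more for missing that edge.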

The final loss is defined as:

\[loss = \dfrac{\mathrm{avg}(loss_{pos}) + \mathrm{avg}(loss_{neg})}{2}\]
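The sketch below combines both averages into the final loss. The negative term is an assumption, not taken from this page: it follows self-adversarial negative sampling (Sun et al., RotatE), where each negative edge's loss \(-\log(1 - score)\) is reweighted by a softmax over the negatives' scores scaled by adversarial_temperature. Consult LinkPredictAdvBCELossFunc for GraphStorm's exact form. Scores are again assumed to be probabilities, and `adv_bce_loss` is a hypothetical name.

```python
import math


def adv_bce_loss(pos_scores, pos_weights, neg_scores, adversarial_temperature, eps=1e-12):
    """Hypothetical sketch of loss = (avg(loss_pos) + avg(loss_neg)) / 2."""
    # Weighted positive term: mean over positive edges of -w * log(score).
    pos = [-w * math.log(max(s, eps)) for s, w in zip(pos_scores, pos_weights)]
    avg_pos = sum(pos) / len(pos)

    # ASSUMED self-adversarial weights: softmax(temperature * score) over the
    # negatives, so higher-scoring (harder) negatives get larger weights.
    logits = [adversarial_temperature * s for s in neg_scores]
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    adv_w = [e / z for e in exps]

    # The weights sum to 1, so the weighted sum acts as a weighted average.
    neg = [-math.log(max(1.0 - s, eps)) for s in neg_scores]
    avg_neg = sum(p * l for p, l in zip(adv_w, neg))

    return (avg_pos + avg_neg) / 2


loss = adv_bce_loss([0.5], [1.0], [0.5, 0.5], adversarial_temperature=1.0)
```

With adversarial_temperature set to 0 the softmax weights are uniform and the negative term reduces to a plain mean; larger values concentrate the loss on the hardest negatives.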

forward(pos_score, neg_score)

The forward function. Computes the loss from the positive and negative edge scores.

Parameters

pos_score: dict of tuple of Tensor

A dict mapping each edge type to a (scores, edge weights) tuple for its positive edges.

neg_score: dict of tuple of Tensor

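A dict mapping each edge type to a (scores, edge weights) tuple for its negative edges. Only positive edges are weighted, so the weights of negative edges do not affect the loss.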

Returns

loss: Tensor

The scalar loss value.
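Both inputs are keyed by edge type. The sketch below shows one way such dicts might be flattened before the loss is computed. It assumes canonical edge types (source node type, relation, destination node type) as keys and uses plain Python lists in place of Tensors; the helper `flatten_scores` and the example edge types are hypothetical. Negative-edge weights are discarded because only positive samples are weighted.

```python
def flatten_scores(score_dict, keys):
    """Hypothetical helper: concatenate (scores, weights) tuples across edge types."""
    scores, weights = [], []
    for etype in keys:
        s, w = score_dict[etype]
        if len(s) != len(w):
            raise ValueError(f"score/weight length mismatch for {etype}")
        scores.extend(s)
        weights.extend(w)
    return scores, weights


# Example edge types (hypothetical), written as canonical (src, relation, dst) tuples.
rates = ("user", "rates", "item")
follows = ("user", "follows", "user")
pos_score = {rates: ([0.9, 0.7], [1.0, 2.0]), follows: ([0.8], [0.5])}
neg_score = {rates: ([0.1, 0.3], [1.0, 1.0]), follows: ([0.2], [1.0])}

# This sketch assumes both dicts cover the same edge types and walks them in one order.
assert pos_score.keys() == neg_score.keys()
keys = list(pos_score)
pos_s, pos_w = flatten_scores(pos_score, keys)
neg_s, _ = flatten_scores(neg_score, keys)  # negative weights are not used
```

Iterating both dicts in a single shared key order keeps positive and negative scores for the same edge type aligned after concatenation.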