LinkPredictBCELossFunc

class graphstorm.model.LinkPredictBCELossFunc(*args, **kwargs)

Bases: GSLayer

Loss function for link prediction tasks using binary cross entropy loss.

The loss is computed with torch.nn.functional.binary_cross_entropy_with_logits and is defined as:

\[loss = -\left[ y \cdot \log \sigma(score) + (1 - y) \cdot \log \left(1 - \sigma(score)\right) \right]\]

where y is 1 for a positive edge and 0 for a negative edge, score is the raw score of an edge computed by the score function, and \(\sigma\) is the sigmoid function, which binary_cross_entropy_with_logits applies internally.
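To make the formula concrete, here is a minimal pure-Python sketch of the per-edge loss. The helper names bce_with_logits and bce_naive are hypothetical and not part of GraphStorm. The sketch assumes each edge score is a raw logit, which is what binary_cross_entropy_with_logits expects before applying the sigmoid. bce_naive transcribes the formula directly, while bce_with_logits uses an algebraically equivalent, numerically stable form.

```python
import math

def bce_with_logits(logit: float, y: float) -> float:
    """Hypothetical helper: BCE loss for one edge from its raw score (logit) and label y.

    Uses the numerically stable identity
    -[y*log(s) + (1-y)*log(1-s)] = max(x, 0) - x*y + log(1 + exp(-|x|)),
    where s = sigmoid(x).
    """
    return max(logit, 0.0) - logit * y + math.log1p(math.exp(-abs(logit)))

def bce_naive(logit: float, y: float) -> float:
    """Hypothetical helper: direct transcription of the formula (sigmoid, then logs)."""
    s = 1.0 / (1.0 + math.exp(-logit))
    return -(y * math.log(s) + (1 - y) * math.log(1 - s))

# Correctly scored edges give a small loss; a mis-scored edge gives a large one.
print(round(bce_with_logits(3.0, 1.0), 4))   # positive edge scored high -> 0.0486
print(round(bce_with_logits(-3.0, 0.0), 4))  # negative edge scored low  -> 0.0486
print(round(bce_with_logits(3.0, 0.0), 4))   # negative edge scored high -> 3.0486
```

The stable form avoids computing log(1 - s) when s rounds to 1.0 for large logits, which is one reason to pass raw scores to a "with logits" loss instead of applying the sigmoid first.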

forward(pos_score, neg_score)

Compute the binary cross entropy loss from the positive and negative edge scores of all edge types.

Parameters

pos_score: dict of Tensor

The scores for positive edges of each edge type.

neg_score: dict of Tensor

The scores for negative edges of each edge type.

Returns

loss: Tensor

The loss value.
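The following is a minimal PyTorch sketch of how a forward pass with this signature could combine the per-edge-type score dicts into a single loss. It is not GraphStorm's source. The function name link_predict_bce_loss and the example edge type ("user", "buys", "item") are hypothetical. The sketch assumes that negative scores may have any shape and can be flattened, and that all edges of all types are pooled into one mean loss (the default reduction of binary_cross_entropy_with_logits).

```python
import torch
import torch.nn.functional as F

def link_predict_bce_loss(pos_score, neg_score):
    """Hypothetical sketch of forward(pos_score, neg_score).

    pos_score, neg_score: dicts mapping an edge type to a tensor of raw edge scores.
    """
    scores, labels = [], []
    for etype, p in pos_score.items():
        n = neg_score[etype]
        # Flatten so positive and negative scores can be concatenated (assumption).
        p, n = p.reshape(-1), n.reshape(-1)
        scores += [p, n]
        # Label 1 for positive edges, 0 for negative edges.
        labels += [torch.ones_like(p), torch.zeros_like(n)]
    score = torch.cat(scores)
    label = torch.cat(labels)
    # The sigmoid is applied inside the loss; the default reduction averages over all edges.
    return F.binary_cross_entropy_with_logits(score, label)

# Usage with one hypothetical edge type: 2 positive edges, 4 negative edges.
etype = ("user", "buys", "item")
pos = {etype: torch.tensor([2.0, 1.5])}
neg = {etype: torch.tensor([[-1.0, 0.5], [-2.0, 0.0]])}
loss = link_predict_bce_loss(pos, neg)
print(loss.item())
```

Pooling all edge types before the mean weights each edge equally, so edge types with more sampled edges contribute more to the loss. A per-type average would be an alternative design.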

property in_dims

The number of input dimensions.

Returns

int : the number of input dimensions.

property out_dims

The number of output dimensions.

Returns

int : the number of output dimensions.