WeightedLinkPredictBCELossFunc

class graphstorm.model.WeightedLinkPredictBCELossFunc(*args, **kwargs)

Bases: GSLayer

Loss function for link prediction tasks that uses a weighted binary cross entropy loss.

torch.nn.functional.binary_cross_entropy_with_logits is used to compute the loss, which is defined as:

\[\text{loss} = - w_e \left[ y \cdot \log(\text{score}) + (1 - y) \cdot \log(1 - \text{score}) \right]\]

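where y is 1 for a positive edge and 0 for a negative edge, and score is the score of edge e computed by the score function. Because binary_cross_entropy_with_logits takes raw scores (logits) as input, score in the formula stands for the sigmoid of that raw score. w_e is the weight of edge e, defined as: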

\[w_e = \begin{cases} 1, & \text{if } e \in G, \\ 0, & \text{if } e \notin G \end{cases}\]

where G is the training graph.
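As a worked illustration of the formula, the plain-Python sketch below evaluates the weighted loss for a few edges. It assumes, as binary_cross_entropy_with_logits does, that each raw score is a logit passed through a sigmoid, and it averages the per-edge losses, mirroring that function's default reduction="mean". The helper name and the sample values are hypothetical.

```python
import math


def weighted_bce_with_logits(logits, labels, weights):
    """Mean of -w_e * [y * log(sigmoid(s)) + (1 - y) * log(1 - sigmoid(s))].

    Hypothetical helper for illustration only; the averaging mirrors the
    default 'mean' reduction of binary_cross_entropy_with_logits.
    """
    total = 0.0
    for s, y, w in zip(logits, labels, weights):
        # Numerically stable log(sigmoid(s)).
        if s >= 0:
            log_p = -math.log1p(math.exp(-s))
        else:
            log_p = s - math.log1p(math.exp(s))
        # log(1 - sigmoid(s)) == log(sigmoid(s)) - s
        log_not_p = log_p - s
        total += -w * (y * log_p + (1 - y) * log_not_p)
    return total / len(logits)


# Two positive edges (y = 1) followed by two negative edges (y = 0).
logits = [2.0, 0.5, -1.0, 0.0]
labels = [1, 1, 0, 0]
# An edge with w_e = 0 contributes nothing to the sum.
weights = [1.0, 1.0, 1.0, 0.0]
print(weighted_bce_with_logits(logits, labels, weights))
```

Computing log(sigmoid(s)) through log1p on the branch that keeps the exponent non-positive avoids overflow for large scores, which is the same reason the PyTorch function takes logits rather than probabilities.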

forward(pos_score, neg_score)

Compute the weighted binary cross entropy loss from the scores of the positive and negative edges.

Parameters

pos_score: dict of tuple of Tensor

The (scores, edge weights) tuple for the positive edges, keyed by edge type.

neg_score: dict of tuple of Tensor

The (scores, edge weights) tuple for the negative edges, keyed by edge type.

Returns

loss: Tensor

The loss value.
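The PyTorch sketch below shows the input layout that forward expects: two dicts keyed by edge type, each value a (scores, edge weights) tuple of Tensors. It does not call GraphStorm; weighted_lp_bce_loss and the edge type ("user", "clicks", "item") are hypothetical. It applies the weights supplied in both tuples and averages the per-edge-type losses; both choices are assumptions and may differ from how GraphStorm handles them internally.

```python
import torch
import torch.nn.functional as F


def weighted_lp_bce_loss(pos_score, neg_score):
    """Hypothetical re-implementation for illustration only.

    pos_score / neg_score: dict mapping an edge type to a
    (scores, edge_weights) tuple of 1-D float Tensors.
    """
    assert pos_score.keys() == neg_score.keys(), "edge types must match"
    losses = []
    for etype, (p_scores, p_weights) in pos_score.items():
        n_scores, n_weights = neg_score[etype]
        scores = torch.cat([p_scores, n_scores])
        # Label 1 for positive edges, 0 for negative edges.
        labels = torch.cat([torch.ones_like(p_scores), torch.zeros_like(n_scores)])
        # Assumption: use the weights supplied for both positive and negative edges.
        weights = torch.cat([p_weights, n_weights])
        losses.append(F.binary_cross_entropy_with_logits(scores, labels, weight=weights))
    # Assumption: the per-edge-type losses are averaged into one scalar.
    return torch.stack(losses).mean()


etype = ("user", "clicks", "item")  # hypothetical canonical edge type
pos = {etype: (torch.tensor([2.0, 0.5]), torch.tensor([1.0, 0.5]))}
neg = {etype: (torch.tensor([-1.0, 0.0, 0.3]), torch.ones(3))}
loss = weighted_lp_bce_loss(pos, neg)
print(loss.item())
```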

property in_dims

The number of input dimensions.

Returns

int : the number of input dimensions.

property out_dims

The number of output dimensions.

Returns

int : the number of output dimensions.