FocalLossFunc

class graphstorm.model.FocalLossFunc(alpha=0.25, gamma=2)

Bases: GSLayer

Focal loss function for classification tasks.

Copied from torchvision.ops.sigmoid_focal_loss and restricted to mean reduction. For details, see https://pytorch.org/vision/main/_modules/torchvision/ops/focal_loss.html.

Focal loss supports only binary classification tasks, i.e., num_classes must be set to 2.
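For reference, the torchvision formula can be written out as follows. Here x is a raw logit, y ∈ {0, 1} is the binary label, σ is the sigmoid, and N is the number of examples. The symbol names are chosen here for illustration and do not come from GraphStorm. The last line is the mean reduction.

```latex
\begin{aligned}
p &= \sigma(x), \quad p_t = p\,y + (1 - p)(1 - y), \quad \alpha_t = \alpha\,y + (1 - \alpha)(1 - y) \\
\mathrm{FL}(x, y) &= -\,\alpha_t\,(1 - p_t)^{\gamma}\,\log p_t \\
\mathcal{L} &= \frac{1}{N}\sum_{i=1}^{N} \mathrm{FL}(x_i, y_i)
\end{aligned}
```

When alpha is negative, torchvision skips the weighting, which amounts to α_t = 1. With γ = 0 and α_t = 1, the expression reduces to ordinary binary cross-entropy, because −log p_t equals the binary cross-entropy for a label in {0, 1}.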

Parameters

alpha: float

Weighting factor in the range (0, 1) that balances positive versus negative examples. Set it to -1 to disable this weighting. Default: 0.25.
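The alpha weighting works as follows. This minimal sketch reproduces the alpha_t term as torchvision.ops.sigmoid_focal_loss computes it. The helper name alpha_weight is hypothetical and is not part of GraphStorm's API.

```python
def alpha_weight(label: int, alpha: float) -> float:
    """Hypothetical helper: class-balancing weight alpha_t for one binary label.

    Mirrors the alpha_t term of torchvision.ops.sigmoid_focal_loss:
    positives (label 1) get alpha, negatives (label 0) get 1 - alpha,
    and a negative alpha (e.g. -1) disables the weighting.
    """
    if alpha < 0:
        return 1.0  # weighting disabled
    return alpha * label + (1 - alpha) * (1 - label)


# With the default alpha=0.25, negatives are weighted 3x more than positives.
print(alpha_weight(1, 0.25))  # 0.25
print(alpha_weight(0, 0.25))  # 0.75
print(alpha_weight(1, -1))    # 1.0
```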

gamma: float

Exponent of the modulating factor (1 - p_t), where p_t is the predicted probability of the true class. It balances easy versus hard examples. Default: 2.
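The next sketch shows how gamma down-weights easy examples by evaluating the modulating factor for a confident prediction and for a poor one. The function name modulating_factor is hypothetical and is used here only for illustration.

```python
def modulating_factor(p_t: float, gamma: float) -> float:
    """Hypothetical helper: return (1 - p_t) ** gamma.

    p_t is the predicted probability of the true class; a larger gamma
    shrinks the contribution of well-classified (high p_t) examples.
    """
    return (1.0 - p_t) ** gamma


# Easy example (p_t = 0.9) vs. hard example (p_t = 0.1) with the default gamma=2.
easy = modulating_factor(0.9, 2)  # about 0.01
hard = modulating_factor(0.1, 2)  # about 0.81
print(hard / easy)                # about 81: the easy example is strongly down-weighted

# gamma=0 makes the factor 1, i.e. plain (alpha-weighted) cross-entropy.
print(modulating_factor(0.9, 0))  # 1.0
```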

New in version 0.4.0: The FocalLossFunc.

forward(logits, labels)

Compute the focal loss, averaged over all examples.

Parameters

logits: torch.Tensor

The predicted logits, i.e., raw scores before the sigmoid is applied.

labels: torch.Tensor

The ground-truth binary labels (0 or 1).

Returns

loss: torch.Tensor

The scalar loss value.
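To illustrate what the forward pass computes, the following pure-Python sketch applies the torchvision formula to one raw logit and one binary label per example, then averages the result. It is not GraphStorm's implementation. The function name focal_loss_mean and the flat one-logit-per-example input layout are assumptions; the real forward takes torch.Tensor inputs. The cross-entropy term uses the standard numerically stable form of binary cross-entropy with logits.

```python
import math
from typing import Sequence


def _sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def focal_loss_mean(logits: Sequence[float], labels: Sequence[int],
                    alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Hypothetical sketch: sigmoid focal loss with mean reduction."""
    if len(logits) != len(labels) or len(logits) == 0:
        raise ValueError("logits and labels must be non-empty and of equal length")
    total = 0.0
    for x, y in zip(logits, labels):
        p = _sigmoid(x)                                  # probability of class 1
        p_t = p * y + (1 - p) * (1 - y)                  # probability of the true class
        # Stable binary cross-entropy with logits, equal to -log(p_t).
        ce = max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
        loss = ce * (1 - p_t) ** gamma                   # focal modulation
        if alpha >= 0:                                   # optional class balancing
            loss *= alpha * y + (1 - alpha) * (1 - y)
        total += loss
    return total / len(logits)                           # mean reduction


# With alpha=-1 and gamma=0 the loss reduces to mean binary cross-entropy.
print(round(focal_loss_mean([0.0, 0.0], [1, 0], alpha=-1, gamma=0), 4))  # 0.6931
```

Because the loss is averaged, its magnitude does not grow with the number of examples.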

property in_dims

The number of input dimensions.

Returns

int : the number of input dimensions.

property out_dims

The number of output dimensions.

Returns

int : the number of output dimensions.