FocalLossFunc
- class graphstorm.model.FocalLossFunc(alpha=0.25, gamma=2)
Bases: GSLayer

Focal loss function for classification tasks.
Copied from torchvision.ops.sigmoid_focal_loss and supports only mean reduction. For more details, see https://pytorch.org/vision/main/_modules/torchvision/ops/focal_loss.html.
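The mean-reduced sigmoid focal loss can be sketched in plain Python as follows. This is a minimal illustration of the torchvision formula, not GraphStorm's `FocalLossFunc.forward`. The function names `sigmoid_focal_loss_mean` and `_stable_sigmoid`, as well as the list-based inputs (one positive-class logit per example, with binary targets), are hypothetical.

```python
import math
from typing import Sequence


def _stable_sigmoid(x: float) -> float:
    # Branch on the sign of x so that math.exp never overflows.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def sigmoid_focal_loss_mean(
    logits: Sequence[float],
    targets: Sequence[float],
    alpha: float = 0.25,
    gamma: float = 2.0,
) -> float:
    """Mean-reduced sigmoid focal loss (hypothetical helper, not a GraphStorm API).

    logits: one raw score per example for the positive class.
    targets: binary labels, 0.0 or 1.0.
    """
    if not logits or len(logits) != len(targets):
        raise ValueError("logits and targets must be non-empty and equally long")
    total = 0.0
    for x, t in zip(logits, targets):
        p = _stable_sigmoid(x)
        # Binary cross-entropy with logits, in its numerically stable form.
        ce = max(x, 0.0) - x * t + math.log1p(math.exp(-abs(x)))
        # p_t is the probability the model assigns to the true class.
        p_t = p * t + (1.0 - p) * (1.0 - t)
        # Down-weight easy examples (p_t close to 1) by (1 - p_t) ** gamma.
        loss = ce * (1.0 - p_t) ** gamma
        # alpha weights positives and (1 - alpha) weights negatives; alpha < 0 disables it.
        if alpha >= 0:
            alpha_t = alpha * t + (1.0 - alpha) * (1.0 - t)
            loss = alpha_t * loss
        total += loss
    # Mean reduction, the only reduction the documented class supports.
    return total / len(logits)
```

Setting alpha=-1 and gamma=0 reduces this sketch to the plain mean binary cross-entropy. This reduction is a convenient sanity check.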
To use focal loss, the classification task must be a binary classification task, i.e., num_classes should be set to 2.

Parameters
- alpha: float
Weighting factor in the range (0, 1) to balance positive vs. negative examples. Use -1 to ignore this parameter. Default: 0.25.
- gamma: float
Exponent of the modulating factor (1 - p_t) to balance easy vs. hard examples. Default: 2.
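To see how alpha and gamma shape the loss, the following sketch computes only the factor that multiplies one example's cross-entropy term. It uses the same torchvision-style formula. Note that `focal_weight` is a hypothetical, illustrative helper and not part of GraphStorm.

```python
def focal_weight(p: float, target: int, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Factor applied to one example's cross-entropy (hypothetical helper).

    p is the predicted probability of the positive class; target is 0 or 1.
    """
    # Probability assigned to the true class.
    p_t = p if target == 1 else 1.0 - p
    # Modulating factor: close to 0 for easy examples, close to 1 for hard ones.
    weight = (1.0 - p_t) ** gamma
    # alpha weights positives, (1 - alpha) weights negatives; alpha = -1 skips this.
    if alpha >= 0:
        weight *= alpha if target == 1 else 1.0 - alpha
    return weight
```

Consider alpha=-1 and gamma=2. A positive example predicted at p = 0.9 keeps about 1% of its cross-entropy (a weight of about 0.01). In contrast, a positive example predicted at p = 0.1 keeps 81%. With the default alpha=0.25, positive examples receive a class weight of 0.25 and negative examples receive 0.75.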
New in version 0.4.0: The FocalLossFunc.