ClassifyLossFunc
- class graphstorm.model.ClassifyLossFunc(multilabel, multilabel_weights=None, imbalance_class_weights=None)
Bases: GSLayer
Loss function for classification tasks.
If multilabel is set to True, torch.nn.BCEWithLogitsLoss is used; otherwise, torch.nn.CrossEntropyLoss is used.
Parameters
- multilabel (bool)
Whether this is multi-label classification.
- multilabel_weights (Tensor)
The label weights for multi-label classification. Default: None
- imbalance_class_weights (Tensor)
The class weights for imbalanced classes. Default: None
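The loss selection described above can be sketched in plain PyTorch. This is a minimal sketch, not GraphStorm's implementation: the class name ClassifyLossSketch is hypothetical, it subclasses torch.nn.Module instead of GSLayer, and it assumes that multilabel_weights is passed as the pos_weight argument of BCEWithLogitsLoss and imbalance_class_weights as the weight argument of CrossEntropyLoss.

```python
import torch
from torch import nn


class ClassifyLossSketch(nn.Module):
    """Hypothetical stand-in for ClassifyLossFunc (nn.Module instead of GSLayer)."""

    def __init__(self, multilabel, multilabel_weights=None, imbalance_class_weights=None):
        super().__init__()
        self.multilabel = multilabel
        if multilabel:
            # Assumption: multilabel_weights acts as the per-label pos_weight.
            self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=multilabel_weights)
        else:
            # Assumption: imbalance_class_weights acts as the per-class weight.
            self.loss_fn = nn.CrossEntropyLoss(weight=imbalance_class_weights)

    def forward(self, logits, labels):
        if self.multilabel:
            # BCEWithLogitsLoss expects float 0/1 targets shaped like logits (N, C).
            return self.loss_fn(logits, labels.float())
        # CrossEntropyLoss expects integer class indices of shape (N,).
        return self.loss_fn(logits, labels.long())


torch.manual_seed(0)
logits = torch.randn(4, 3)  # 4 samples, 3 classes/labels

# Single-label: one class index per sample; class 2 is up-weighted as if it were rare.
class_weights = torch.tensor([1.0, 1.0, 3.0])
single_labels = torch.tensor([0, 2, 1, 2])
single = ClassifyLossSketch(multilabel=False, imbalance_class_weights=class_weights)
single_loss = single(logits, single_labels)

# Multi-label: a 0/1 indicator for every (sample, label) pair.
label_weights = torch.tensor([1.0, 2.0, 0.5])
multi_labels = torch.tensor([[1, 0, 1], [0, 1, 0], [1, 1, 0], [0, 0, 1]])
multi = ClassifyLossSketch(multilabel=True, multilabel_weights=label_weights)
multi_loss = multi(logits, multi_labels)
```

Casting the labels inside forward lets callers pass integer label tensors in both modes, since BCEWithLogitsLoss needs float targets while CrossEntropyLoss needs integer class indices.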