ClassifyLossFunc

class graphstorm.model.ClassifyLossFunc(multilabel, multilabel_weights=None, imbalance_class_weights=None)

Bases: GSLayer

Loss function for classification tasks.

If multilabel is True, torch.nn.BCEWithLogitsLoss is used; otherwise, torch.nn.CrossEntropyLoss is used.
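To illustrate the two modes, the pure-Python sketch below computes what torch.nn.CrossEntropyLoss and torch.nn.BCEWithLogitsLoss return with their default mean reduction and no weights. The helper names (cross_entropy, bce_with_logits, classify_loss) are hypothetical and not part of GraphStorm. The sketch assumes that single-label targets are class indices and multi-label targets are multi-hot 0/1 vectors, which is the input format the two PyTorch losses expect.

```python
import math
from typing import Sequence


def _log_sigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x)).
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))


def cross_entropy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Mean softmax cross-entropy; each label is a class index (single-label case)."""
    total = 0.0
    for row, y in zip(logits, labels):
        m = max(row)  # subtract the max for a stable log-sum-exp
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_z - row[y]  # -log(softmax(row)[y])
    return total / len(labels)


def bce_with_logits(logits: Sequence[Sequence[float]],
                    labels: Sequence[Sequence[int]]) -> float:
    """Mean binary cross-entropy over every (sample, class) entry; labels are multi-hot."""
    total, count = 0.0, 0
    for row, ys in zip(logits, labels):
        for x, y in zip(row, ys):
            # -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))], using log(1-sigmoid(x)) = log(sigmoid(-x))
            total -= y * _log_sigmoid(x) + (1 - y) * _log_sigmoid(-x)
            count += 1
    return total / count


def classify_loss(logits, labels, multilabel: bool) -> float:
    # Hypothetical dispatcher mirroring the documented rule:
    # BCE-with-logits for multi-label, softmax cross-entropy otherwise.
    return bce_with_logits(logits, labels) if multilabel else cross_entropy(logits, labels)
```

For example, two equal logits give a loss of log 2 in either mode. Both losses take raw, unnormalized logits, so the model should not apply a sigmoid or softmax before the loss.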

Parameters

multilabel: bool

Whether this is multi-label classification.

multilabel_weights: Tensor

The label weights for multi-label classification. Default: None

imbalance_class_weights: Tensor

The class weights for imbalanced classes. Default: None
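The sketch below shows how the two weight tensors would affect the loss if they are passed to the underlying PyTorch losses. It assumes, without confirmation from this page, that imbalance_class_weights becomes the `weight` argument of torch.nn.CrossEntropyLoss (one entry per class) and that multilabel_weights becomes the `pos_weight` argument of torch.nn.BCEWithLogitsLoss (one entry per label). The helper names weighted_cross_entropy and weighted_bce_with_logits are hypothetical.

```python
import math
from typing import Sequence


def _log_sigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x)).
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))


def weighted_cross_entropy(logits: Sequence[Sequence[float]], labels: Sequence[int],
                           class_weights: Sequence[float]) -> float:
    """Cross-entropy where sample n is scaled by class_weights[labels[n]].

    Assumption: imbalance_class_weights plays the role of CrossEntropyLoss(weight=...).
    With mean reduction, the weighted sum is divided by the total weight of the
    targets, not by the batch size.
    """
    num, den = 0.0, 0.0
    for row, y in zip(logits, labels):
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        w = class_weights[y]
        num += w * (log_z - row[y])
        den += w
    return num / den


def weighted_bce_with_logits(logits: Sequence[Sequence[float]],
                             labels: Sequence[Sequence[int]],
                             pos_weight: Sequence[float]) -> float:
    """BCE-with-logits where positive targets of label c are scaled by pos_weight[c].

    Assumption: multilabel_weights plays the role of BCEWithLogitsLoss(pos_weight=...).
    Negative targets are left unweighted; the result is averaged over all entries.
    """
    total, count = 0.0, 0
    for row, ys in zip(logits, labels):
        for c, (x, y) in enumerate(zip(row, ys)):
            total -= pos_weight[c] * y * _log_sigmoid(x) + (1 - y) * _log_sigmoid(-x)
            count += 1
    return total / count
```

Under these assumptions, a larger class weight pulls the single-label loss toward the samples of that class. A larger pos_weight only increases the penalty for missed positives of that label and leaves negatives unchanged.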

forward(logits, labels)

Compute the classification loss from the prediction logits and the training labels.

Parameters

logits: torch.Tensor

The prediction results as unnormalized logits.

labels: torch.Tensor

The training labels.

Returns

loss: Tensor

The loss value.

property in_dims

The number of input dimensions.

Returns

int : the number of input dimensions.

property out_dims

The number of output dimensions.

Returns

int : the number of output dimensions.