Deal with Imbalanced Labels in Classification/Regression

In some cases, the labels of different classes are imbalanced, i.e., some classes have far more or far fewer data points than others. For example, most fraud detection tasks have only a small number of fraudulent activities (positive labels) versus a huge number of legitimate activities (negative labels). Regression tasks can also have imbalanced labels when a few dominant values account for most of the labels. If not handled properly, imbalanced labels can significantly degrade the performance of classification and regression models. For example, when the training data is dominated by negative labels, a model may learn to classify all unseen samples as negative, which still yields high accuracy while missing every positive sample. GraphStorm provides several ways to tackle the class imbalance problem.

For classification tasks, users can configure two arguments in the command line interfaces (CLIs): imbalance_class_weights and class_loss_func.

The imbalance_class_weights argument lets users assign a weight to each class, forcing models to focus more on the classes with higher weights. For example, if there are 10 positive labels versus 90 negative labels, you can set imbalance_class_weights to 0.1,0.9, meaning that class 0 (usually the negative labels) has weight 0.1 and class 1 (usually the positive labels) has weight 0.9. This places more importance on correctly classifying positive samples and less on negative ones. Below is an example of how to set imbalance_class_weights in a YAML configuration file.

imbalance_class_weights: 0.1,0.9
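To make the effect of these weights concrete, here is a minimal pure-Python sketch; it is not GraphStorm's implementation. It assumes the per-class weights are applied the way PyTorch's weighted cross-entropy applies them: each sample's loss is multiplied by the weight of its true class, and the mean is normalized by the sum of those weights. The helper names `parse_class_weights`, `inverse_frequency_weights`, and `weighted_cross_entropy` are hypothetical. The inverse-frequency heuristic is one common way to pick weights (an assumption, not a GraphStorm rule); for the 90 negative versus 10 positive example it reproduces the 0.1,0.9 setting.

```python
import math

def parse_class_weights(value):
    """Hypothetical helper: parse a comma-separated string such as "0.1,0.9"
    (the YAML value format above) into a list of floats."""
    return [float(w) for w in value.split(",")]

def inverse_frequency_weights(class_counts):
    """Heuristic: weight each class by 1 / count, then normalize the
    weights so they sum to 1. Rare classes receive larger weights."""
    inv = [1.0 / c for c in class_counts]
    total = sum(inv)
    return [w / total for w in inv]

def weighted_cross_entropy(probs, labels, class_weights):
    """Weighted mean of -log p(true class), normalized by the summed weights
    of the true classes (mirrors PyTorch's weighted mean reduction).

    probs: per-sample lists of class probabilities.
    labels: per-sample integer class indices.
    """
    weighted_losses = 0.0
    weight_sum = 0.0
    for p, y in zip(probs, labels):
        w = class_weights[y]
        weighted_losses += -w * math.log(p[y])
        weight_sum += w
    return weighted_losses / weight_sum

# 90 negative (class 0) and 10 positive (class 1) labels, as in the text.
weights = inverse_frequency_weights([90, 10])
print([round(w, 3) for w in weights])  # [0.1, 0.9]
print(parse_class_weights("0.1,0.9"))  # [0.1, 0.9]

# One confidently correct negative and one poorly predicted positive.
probs = [[0.95, 0.05], [0.8, 0.2]]
labels = [0, 1]
print(weighted_cross_entropy(probs, labels, [1.0, 1.0]))  # unweighted loss
print(weighted_cross_entropy(probs, labels, weights))     # positive error dominates
```

With equal weights the two samples contribute equally; with 0.1,0.9 the error on the positive sample accounts for almost all of the loss, so gradient updates concentrate on the minority class.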

You can also set class_loss_func to focal, which uses the focal loss function in binary classification tasks. The focal loss function is designed for imbalanced classes. Its formula is \(loss(p_t) = -\alpha_t(1-p_t)^{\gamma}\log(p_t)\), where \(p_t = p\) if \(y=1\), and \(p_t = 1-p\) otherwise. Here \(p\) is the predicted probability of the positive class in a binary classification task. In the standard formulation, \(\alpha_t\) follows the same pattern: \(\alpha_t = \alpha\) if \(y=1\), and \(\alpha_t = 1-\alpha\) otherwise. This function has two hyperparameters, \(\alpha\) and \(\gamma\), corresponding to the alpha and gamma configurations in GraphStorm. Because the factor \((1-p_t)^{\gamma}\) shrinks the loss of well-classified samples (large \(p_t\)), larger values of gamma make the model focus its updates on hard cases, which helps detect more positive samples when the positive-to-negative ratio is small. There is no clear guideline for choosing alpha; you can start with its default value (0.25) and then search for an optimal value. Focal loss only works for binary classification tasks, so num_classes should be set to 2. Below is an example of how to set the focal loss function in a YAML configuration file.

class_loss_func: focal

gamma: 10.0
alpha: 0.5
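The following pure-Python sketch evaluates the focal loss formula for single samples and compares it with plain binary cross-entropy, to show how \((1-p_t)^{\gamma}\) suppresses easy samples. It assumes the standard \(\alpha_t\) convention (alpha for positives, 1 - alpha for negatives); the function names `binary_focal_loss` and `cross_entropy` are hypothetical and are not GraphStorm APIs.

```python
import math

def binary_focal_loss(p, y, alpha=0.25, gamma=2.0):
    """Focal loss for one sample, following the formula in the text.

    p: predicted probability of the positive class (0 < p < 1).
    y: true label, 1 for positive and 0 for negative.
    alpha_t uses the standard convention (assumed): alpha for y=1,
    1 - alpha for y=0.
    """
    p_t = p if y == 1 else 1.0 - p
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

def cross_entropy(p, y):
    """Plain binary cross-entropy for the same sample, for comparison."""
    p_t = p if y == 1 else 1.0 - p
    return -math.log(p_t)

# An easy positive (p = 0.9) and a hard positive (p = 0.1).
for p in (0.9, 0.1):
    ratio = binary_focal_loss(p, 1, alpha=0.5, gamma=2.0) / cross_entropy(p, 1)
    print(f"p={p}: focal/CE ratio = {ratio:.3f}")
```

With alpha = 0.5 and gamma = 2, the easy sample's cross-entropy is scaled by \(0.5 \cdot 0.1^2 = 0.005\), while the hard sample's is scaled by \(0.5 \cdot 0.9^2 = 0.405\). Raising gamma (e.g., to 10.0 as in the YAML example) widens this gap further.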

Apart from focal loss and class weights, you can also output classification results as probabilities of the positive and negative classes by setting the return_proba configuration to true. By default, GraphStorm outputs classification results as argmax values, e.g., either 0s or 1s in binary tasks, which is equivalent to using 0.5 as the threshold that separates negative from positive samples. With probabilities as outputs, you can apply different thresholds to achieve the desired trade-off. For example, if you need higher recall to catch more suspicious positive samples, a smaller threshold, e.g., 0.25, classifies more samples as positive. Alternatively, you can use methods such as the ROC curve or the Precision-Recall curve to determine the optimal threshold. Below is an example of how to set return_proba in a YAML configuration file.

return_proba: true
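Once probabilities are available, thresholding happens outside the model. The sketch below assumes you have already loaded the positive-class probabilities and the true labels into two Python lists (how GraphStorm's prediction files are laid out is not covered here). The helpers `predict_with_threshold`, `precision_recall`, and `best_f1_threshold` are hypothetical, and the F1 scan is a simple stand-in for a full Precision-Recall curve analysis. The probabilities and labels are toy values for illustration only.

```python
def predict_with_threshold(pos_probs, threshold=0.5):
    """Label a sample positive when its positive-class probability >= threshold."""
    return [1 if p >= threshold else 0 for p in pos_probs]

def precision_recall(preds, labels):
    """Precision and recall of the positive class (0.0 when undefined)."""
    tp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 1)
    fp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 0)
    fn = sum(1 for p, y in zip(preds, labels) if p == 0 and y == 1)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

def best_f1_threshold(pos_probs, labels):
    """Try every distinct predicted probability as a threshold and keep
    the one with the highest F1 score."""
    best_t, best_f1 = 0.5, -1.0
    for t in sorted(set(pos_probs)):
        prec, rec = precision_recall(predict_with_threshold(pos_probs, t), labels)
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        if f1 > best_f1:
            best_t, best_f1 = t, f1
    return best_t, best_f1

# Toy positive-class probabilities and true labels (illustrative only).
pos_probs = [0.05, 0.10, 0.20, 0.30, 0.45, 0.60, 0.15, 0.35]
labels = [0, 0, 0, 1, 1, 1, 0, 0]

for t in (0.5, 0.25):
    print(t, precision_recall(predict_with_threshold(pos_probs, t), labels))
print(best_f1_threshold(pos_probs, labels))
```

On this toy data, lowering the threshold from 0.5 to 0.25 raises recall from 1/3 to 1.0 while precision drops from 1.0 to 0.75, which is the recall-for-precision trade-off described above.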

For regression tasks where some values, e.g., 0s, dominate the labels, GraphStorm provides the shrinkage loss function, which can be enabled by setting the regression_loss_func configuration to shrinkage. Its formula is \(loss = l^2/(1 + \exp \left( \alpha \cdot (\gamma - l)\right))\), where \(l\) is the absolute difference between a prediction and its label. The shrinkage loss function also has the \(\alpha\) and \(\gamma\) hyperparameters, which you can set with the same alpha and gamma configurations used by the focal loss function. When \(l\) is small relative to \(\gamma\), the denominator becomes large and the loss is shrunk; when \(l\) is much larger than \(\gamma\), the denominator approaches 1 and the loss approaches \(l^2\). In other words, the shrinkage loss reduces the importance of easy samples (those with small \(l\), such as \(l < 0.5\)) and keeps the loss of hard samples nearly unchanged. Below is an example of how to set the shrinkage loss function in a YAML configuration file.

regression_loss_func: shrinkage

gamma: 0.2
alpha: 5
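The sketch below evaluates the shrinkage loss formula with the alpha and gamma values from the YAML example and compares it with plain squared error. It is a pure-Python illustration, not GraphStorm's implementation; `shrinkage_loss` and `squared_loss` are hypothetical names.

```python
import math

def shrinkage_loss(pred, label, alpha=5.0, gamma=0.2):
    """Shrinkage loss for one sample, following the formula in the text:
    l^2 / (1 + exp(alpha * (gamma - l))), where l = |pred - label|."""
    l = abs(pred - label)
    return l ** 2 / (1.0 + math.exp(alpha * (gamma - l)))

def squared_loss(pred, label):
    """Plain squared error, for comparison."""
    return (pred - label) ** 2

# Ratio of shrinkage loss to squared loss for increasing absolute errors.
for err in (0.05, 0.2, 0.5, 1.0, 2.0):
    ratio = shrinkage_loss(err, 0.0) / squared_loss(err, 0.0)
    print(f"l={err}: shrinkage/squared = {ratio:.3f}")
```

The ratio equals \(1/(1 + \exp(\alpha(\gamma - l)))\): it is exactly 0.5 at \(l = \gamma\), smaller for errors below \(\gamma\), and approaches 1 as \(l\) grows. Larger alpha makes this transition sharper, and gamma sets where it happens.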