Deal with Imbalanced Labels in Classification/Regression
In some cases, the number of labels of different classes can be imbalanced, i.e., some classes have far more or far fewer data points than others. For example, most fraud detection tasks have only a small number of fraudulent activities (positive labels) versus a huge number of legitimate activities (negative labels). Even in regression tasks, a few dominant values can cause imbalanced labels. If not handled properly, imbalanced labels can significantly degrade classification/regression model performance. For example, when a model is trained on overwhelmingly negative labels, it may learn to classify all unseen samples as negative. GraphStorm provides several ways to tackle the class imbalance problem.
For classification tasks, users can configure two arguments in the command line interfaces (CLIs): imbalance_class_weights and class_loss_func.

The imbalance_class_weights argument allows users to assign a scale weight to each class, forcing models to pay more attention to classes with higher weights. For example, if there are 10 positive labels versus 90 negative labels, you can set imbalance_class_weights to 0.1,0.9, meaning class 0 (usually for negative labels) has weight 0.1, and class 1 (usually for positive labels) has weight 0.9. This places more importance on correctly classifying positive samples and less on negative ones. Below is an example of how to set imbalance_class_weights in a YAML configuration file.
imbalance_class_weights: 0.1,0.9
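For intuition, below is a minimal PyTorch sketch (not GraphStorm internals) showing how such per-class weights change a cross-entropy loss; the logits and labels are made up for illustration, and the weights mirror the YAML example above.

import torch
import torch.nn as nn

# Hypothetical logits for four samples over two classes, plus their labels.
logits = torch.tensor([[2.0, -1.0], [0.5, 0.3], [-0.2, 1.5], [1.0, 0.0]])
labels = torch.tensor([0, 0, 1, 0])

# Unweighted loss treats both classes equally.
plain_loss = nn.CrossEntropyLoss()(logits, labels)

# Weighted loss: class 0 (negative) gets 0.1 and class 1 (positive) gets 0.9,
# so errors on the minority positive class cost roughly nine times more.
weighted_loss = nn.CrossEntropyLoss(weight=torch.tensor([0.1, 0.9]))(logits, labels)

print(plain_loss.item(), weighted_loss.item())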
You can also set focal as the value of the class_loss_func configuration, which uses the focal loss function for binary classification tasks. The focal loss function is designed for imbalanced classes. Its formula is \(loss(p_t) = -\alpha_t(1-p_t)^{\gamma}\log(p_t)\), where \(p_t = p\) if \(y = 1\), and \(p_t = 1-p\) otherwise. Here \(p\) is the predicted probability of the positive class in a binary classification. This function has two hyperparameters, \(\alpha\) and \(\gamma\), corresponding to the alpha and gamma configurations in GraphStorm. Larger values of gamma put more weight on hard, misclassified cases, which helps detect more positive samples when the positive to negative ratio is small. There is no clear guideline for values of alpha; you can start with its default value (0.25) and then search for the optimal value. Focal loss only works for binary classification tasks, so num_classes must be set to 2.

Below is an example of how to set the focal loss function in a YAML configuration file.
class_loss_func: focal
gamma: 10.0
alpha: 0.5
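To make the formula concrete, below is a minimal PyTorch sketch of binary focal loss following the definition above; it is an illustration, not GraphStorm's internal implementation, and it assumes the standard convention \(\alpha_t = \alpha\) if \(y = 1\), and \(\alpha_t = 1 - \alpha\) otherwise.

import torch

def binary_focal_loss(p, y, alpha=0.25, gamma=2.0):
    # Implements loss(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t),
    # where p is the predicted positive-class probability and y is in {0, 1}.
    p_t = torch.where(y == 1, p, 1 - p)
    # Assumed convention: alpha_t = alpha for positives, 1 - alpha for negatives.
    alpha_t = torch.where(y == 1, torch.full_like(p, alpha), torch.full_like(p, 1 - alpha))
    return (-alpha_t * (1 - p_t) ** gamma * torch.log(p_t + 1e-8)).mean()

p = torch.tensor([0.9, 0.6, 0.2])  # predicted probabilities of the positive class
y = torch.tensor([1, 0, 1])
print(binary_focal_loss(p, y).item())

Note how a confident correct prediction (p_t close to 1) contributes almost nothing, while a hard misclassified sample dominates the loss; larger gamma sharpens this effect.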
Apart from focal loss and class weights, you can also output the classification results as probabilities of the positive and negative classes by setting the return_proba configuration to true. By default, GraphStorm outputs classification results as argmax values, e.g., either 0s or 1s in binary tasks, which is equivalent to using 0.5 as the threshold for separating negative from positive samples. With probabilities as outputs, you can apply different thresholds to achieve the desired outcome. For example, if you need higher recall to catch more suspicious positive samples, a smaller threshold, e.g., 0.25, will classify more samples as positive. You can also use methods like the ROC curve or the Precision-Recall curve to determine the optimal threshold. Below is an example of how to set return_proba in a YAML configuration file.
return_proba: true
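As a sketch of how the returned probabilities can be thresholded, the snippet below applies a custom cutoff and sweeps a Precision-Recall curve with scikit-learn; the probability and label arrays are made up for illustration.

import numpy as np
from sklearn.metrics import precision_recall_curve

# Hypothetical positive-class probabilities and ground-truth labels.
proba = np.array([0.95, 0.70, 0.40, 0.30, 0.15, 0.05])
y_true = np.array([1, 1, 1, 0, 0, 0])

# The default argmax output is equivalent to thresholding at 0.5.
pred_default = (proba >= 0.5).astype(int)

# A lower threshold trades precision for recall, catching more positives.
pred_recall = (proba >= 0.25).astype(int)

# Alternatively, sweep thresholds and pick the one maximizing F1.
precision, recall, thresholds = precision_recall_curve(y_true, proba)
f1 = 2 * precision * recall / (precision + recall + 1e-12)
best_threshold = thresholds[np.argmax(f1[:-1])]
print(pred_default, pred_recall, best_threshold)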
For regression tasks whose labels contain dominant values, e.g., many 0s, GraphStorm provides the shrinkage loss function, which can be enabled by setting shrinkage as the value of the regression_loss_func configuration. Its formula is \(loss = l^2/(1 + \exp\left(\alpha \cdot (\gamma - l)\right))\), where \(l\) is the absolute difference between predictions and labels. The shrinkage loss function also has \(\alpha\) and \(\gamma\) hyperparameters, and you can use the same alpha and gamma configurations as for the focal loss function to modify their values. The shrinkage loss penalizes the importance of easy samples (where \(l < 0.5\)) and keeps the loss of hard samples unchanged. Below is an example of how to set the shrinkage loss function in a YAML configuration file.
regression_loss_func: shrinkage
gamma: 0.2
alpha: 5
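For reference, below is a minimal PyTorch sketch of the shrinkage loss formula above, again as an illustration rather than GraphStorm's internal code; the predictions and labels are made up, and the hyperparameters mirror the YAML example.

import torch

def shrinkage_loss(pred, label, alpha=5.0, gamma=0.2):
    # Implements loss = l^2 / (1 + exp(alpha * (gamma - l))),
    # where l is the absolute difference between predictions and labels.
    l = torch.abs(pred - label)
    return (l ** 2 / (1 + torch.exp(alpha * (gamma - l)))).mean()

pred = torch.tensor([0.05, 0.0, 2.5])
label = torch.tensor([0.0, 0.0, 1.0])
# Easy samples (small l) are strongly down-weighted by the sigmoid-like factor,
# while hard samples keep a loss close to the plain squared error l^2.
print(shrinkage_loss(pred, label).item())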