ShrinkageLossFunc

class graphstorm.model.ShrinkageLossFunc(alpha=10, gamma=0.2)

Bases: GSLayer

Shrinkage Loss for imbalanced regression tasks.

The shrinkage loss is defined as:

\[ \text{loss} = \frac{l^2}{1 + \exp\left( \alpha \cdot (\gamma - l) \right)} \]

where \(l\) is the absolute difference between the predicted value and the ground truth, and \(\alpha\) and \(\gamma\) are hyper-parameters that control the shrinkage speed and the localization, respectively.
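As a quick check of the formula, the per-sample loss can be written in plain Python. This is a minimal sketch for illustration only. The function name `shrinkage_loss` is hypothetical and is not part of the GraphStorm API.

```python
import math


def shrinkage_loss(pred: float, target: float, alpha: float = 10.0, gamma: float = 0.2) -> float:
    """Per-sample shrinkage loss: l**2 / (1 + exp(alpha * (gamma - l))).

    Hypothetical helper for illustration; not GraphStorm code.
    """
    l = abs(pred - target)  # absolute prediction error
    return l ** 2 / (1.0 + math.exp(alpha * (gamma - l)))


print(shrinkage_loss(0.2, 0.0))  # l equals gamma, so the loss is about half of 0.04
```

With the default hyper-parameters, a prediction of 0.2 against a label of 0.0 gives l = γ, so the denominator is exactly 2 and the loss is about 0.02, half of the plain squared error 0.04.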

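The shrinkage loss only reduces the weight of easy samples (when l < 0.5) and keeps the loss of hard samples almost unchanged: for large l the exponential term vanishes, the denominator approaches 1, and the loss approaches the squared error l².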
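The easy/hard distinction can be made concrete by evaluating the modulating factor 1 / (1 + exp(α(γ − l))) that multiplies l² for a few error values. This is a minimal sketch using the default α = 10 and γ = 0.2. The helper name `modulating_factor` is hypothetical.

```python
import math


def modulating_factor(l: float, alpha: float = 10.0, gamma: float = 0.2) -> float:
    """Weight the shrinkage loss applies to the squared error l**2 (hypothetical helper)."""
    return 1.0 / (1.0 + math.exp(alpha * (gamma - l)))


for err in (0.05, 0.2, 0.5, 1.0):
    # Small errors (easy samples) get a small weight; large errors keep close to l**2.
    print(f"l={err:<4} factor={modulating_factor(err):.3f}")
```

With the defaults, the factor is about 0.18 at l = 0.05, exactly 0.5 at l = 0.2, about 0.95 at l = 0.5, and very close to 1 at l = 1.0, which matches the description of easy samples being down-weighted while hard samples are left nearly untouched.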

For more details, please refer to the paper “Deep Regression Tracking with Shrinkage Loss” (https://openaccess.thecvf.com/content_ECCV_2018/html/Xiankai_Lu_Deep_Regression_Tracking_ECCV_2018_paper.html).

Parameters

alpha: float

A hyper-parameter that controls how quickly the loss shrinks as the prediction error decreases. Default: 10.

gamma: float

A hyper-parameter that controls the localization of the loss along the x-axis (the prediction error), i.e. the error value around which the shrinkage takes effect. Default: 0.2.
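The roles of the two hyper-parameters follow from rewriting the loss as the squared error times a sigmoid-shaped modulating factor. In the derivation below, f and σ (the logistic sigmoid) are notation introduced here for illustration: γ is the error at which the factor equals 1/2, and α sets the slope of the factor at that point.

```latex
\[
  f(l) = \frac{1}{1 + \exp\left(\alpha (\gamma - l)\right)} = \sigma\left(\alpha (l - \gamma)\right),
  \qquad \text{loss} = l^2 \, f(l)
\]
\[
  f(\gamma) = \frac{1}{1 + e^{0}} = \frac{1}{2},
  \qquad
  f'(l) = \alpha \, f(l)\left(1 - f(l)\right)
  \;\Rightarrow\; f'(\gamma) = \frac{\alpha}{4}
\]
```

A larger α therefore makes the transition from "shrunk" to "unchanged" sharper, while γ shifts where that transition happens.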

New in version 0.4.1: Added shrinkage loss for regression tasks.

forward(logits, labels)

Compute the shrinkage loss between the predictions and the labels.

Parameters

logits: torch.Tensor

The prediction results.

labels: torch.Tensor

The training labels.

Returns

loss: torch.Tensor

The loss value.
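The `forward(logits, labels)` call can be illustrated with a small PyTorch module that follows the formula above. This is a sketch under stated assumptions, not GraphStorm's implementation: the class name `SketchShrinkageLoss` is hypothetical, and the mean reduction over all elements is an assumption.

```python
import torch
import torch.nn as nn


class SketchShrinkageLoss(nn.Module):
    """Hypothetical stand-in mirroring forward(logits, labels); not GraphStorm's source.

    Assumes a mean reduction over all elements.
    """

    def __init__(self, alpha: float = 10.0, gamma: float = 0.2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        if logits.shape != labels.shape:
            # Guard against silent broadcasting, e.g. (N, 1) predictions vs (N,) labels.
            raise ValueError(f"shape mismatch: {tuple(logits.shape)} vs {tuple(labels.shape)}")
        # Labels may arrive as integers; cast so the arithmetic stays in floating point.
        labels = labels.to(logits.dtype)
        l = torch.abs(logits - labels)  # absolute error per element
        # sigmoid(alpha * (l - gamma)) == 1 / (1 + exp(alpha * (gamma - l)))
        factor = torch.sigmoid(self.alpha * (l - self.gamma))
        return (l ** 2 * factor).mean()


loss_fn = SketchShrinkageLoss()
logits = torch.tensor([0.2, 1.5, 3.0])
labels = torch.tensor([0.0, 1.0, 0.0])
print(loss_fn(logits, labels))
```

Writing the factor with `torch.sigmoid` is mathematically identical to the formula but avoids overflowing `exp` when α(γ − l) is large. The shape check is there because a regression head often emits predictions of shape (N, 1) while labels have shape (N,), and broadcasting would otherwise produce an N × N error matrix without any warning.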

property in_dims

The number of input dimensions.

Returns

int : the number of input dimensions.

property out_dims

The number of output dimensions.

Returns

int : the number of output dimensions.