EntityClassifier

class graphstorm.model.EntityClassifier(in_dim, num_classes, multilabel, dropout=0, norm=None)

Bases: GSLayer

Decoder for node classification tasks.

Parameters

in_dim: int

The input dimension size.

num_classes: int

The number of classes to predict.

multilabel: bool

Whether this is a multi-label classification decoder.

dropout: float

Dropout rate. Default: 0.

norm: str

Normalization method. Not currently used; reserved for more complex node classifier implementations. Default: None.
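To make the constructor arguments concrete, here is a plain-Python sketch of the state a decoder with this signature might hold. It assumes the decoder is a single linear projection backed by an in_dim by num_classes weight matrix. The class name NodeClassifierSketch and the extra seed argument are hypothetical and not part of graphstorm. In this sketch, dropout and norm are only validated and stored, not applied.

```python
import random
from typing import List


class NodeClassifierSketch:
    """Hypothetical stand-in mirroring the documented constructor arguments.

    This is NOT graphstorm.model.EntityClassifier; it is a framework-free illustration.
    """

    def __init__(self, in_dim: int, num_classes: int, multilabel: bool,
                 dropout: float = 0.0, norm=None, seed: int = 0):
        if in_dim <= 0 or num_classes <= 0:
            raise ValueError("in_dim and num_classes must be positive")
        if not 0.0 <= dropout < 1.0:
            raise ValueError("dropout must be in [0, 1)")
        self.in_dim = in_dim
        self.num_classes = num_classes
        self.multilabel = multilabel
        self.dropout = dropout  # stored only; this sketch never applies dropout
        self.norm = norm        # accepted but unused, matching the description above
        # Assumption: one (in_dim x num_classes) weight matrix with a small random init.
        # `seed` is a hypothetical argument that makes the init reproducible.
        rng = random.Random(seed)
        self.weight: List[List[float]] = [
            [rng.uniform(-0.1, 0.1) for _ in range(num_classes)]
            for _ in range(in_dim)
        ]


decoder = NodeClassifierSketch(in_dim=4, num_classes=3, multilabel=False)
print(len(decoder.weight), len(decoder.weight[0]))  # 4 3
```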

forward(inputs)

Node classification decoder forward computation.

Parameters

inputs: Tensor

The input embeddings.

Returns

Tensor: the prediction logits.
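As a rough illustration of what "prediction logits" means here, the sketch below projects a batch of node embeddings (num_nodes by in_dim) to logits (num_nodes by num_classes) with one matrix multiplication. It assumes the decoder is a single bias-free linear layer; the function name forward_sketch and the nested-list weight layout are hypothetical choices for a framework-free example.

```python
from typing import List

Matrix = List[List[float]]


def forward_sketch(inputs: Matrix, weight: Matrix) -> Matrix:
    """Map (num_nodes x in_dim) embeddings to (num_nodes x num_classes) logits.

    Assumption: a single bias-free linear projection, logits = inputs @ weight.
    """
    in_dim = len(weight)
    for row in inputs:
        if len(row) != in_dim:
            raise ValueError(f"expected {in_dim} features, got {len(row)}")
    num_classes = len(weight[0])
    return [
        [sum(row[k] * weight[k][c] for k in range(in_dim)) for c in range(num_classes)]
        for row in inputs
    ]


emb = [[1.0, 0.0], [0.5, 2.0]]  # 2 nodes, in_dim = 2
w = [[1.0, -1.0, 0.0],          # weight is in_dim x num_classes = 2 x 3
     [0.0, 1.0, 2.0]]
print(forward_sketch(emb, w))   # [[1.0, -1.0, 0.0], [0.5, 1.5, 4.0]]
```

The logits are unnormalized scores: they can be negative and do not sum to one per node.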

predict(inputs)

Node classification prediction computation.

Parameters

inputs: Tensor

The input embeddings.

Returns

Tensor: the argmax of the prediction results (the predicted class index for each node), or the maximum of the prediction results if multilabel is True.
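The single-label branch of predict reduces each node's logits to the index of the largest one. The sketch below shows that, plus a multi-label branch. The multi-label rule here is an assumption: it uses a per-class sigmoid thresholded at 0.5, which is a common convention, and the description above does not confirm it. The function name predict_sketch is hypothetical.

```python
from typing import List


def predict_sketch(logits: List[List[float]], multilabel: bool) -> list:
    """Turn per-node logits into predictions.

    Single-label: argmax over classes for each node.
    Multi-label (ASSUMPTION): mark class c as positive when sigmoid(logit) > 0.5,
    which holds exactly when the logit is > 0.
    """
    if multilabel:
        return [[1 if x > 0 else 0 for x in row] for row in logits]
    # Ties resolve to the lowest class index.
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


logits = [[1.0, -1.0, 0.0], [0.5, 1.5, 4.0]]
print(predict_sketch(logits, multilabel=False))  # [0, 2]
print(predict_sketch(logits, multilabel=True))   # [[1, 0, 0], [1, 1, 1]]
```

Comparing the raw logit with 0 avoids computing the exponential and gives the same answer as thresholding the sigmoid at 0.5.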

predict_proba(inputs)

Node classification prediction computation that returns normalized prediction results.

Parameters

inputs: Tensor

The input embeddings.

Returns

Tensor: normalized prediction results.

property in_dims

Return the input dimension size (in_dim), which is given at class initialization.

property out_dims

Return the output dimension size (num_classes), which is given at class initialization.
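The two properties are read-only views of values passed to the constructor. The sketch below shows that pattern with Python's @property. The class name DimsSketch and the helper check_encoder_width are hypothetical, and mapping out_dims to the number of classes is an assumption based on the constructor signature.

```python
class DimsSketch:
    """Hypothetical class showing read-only in_dims/out_dims properties."""

    def __init__(self, in_dim: int, num_classes: int):
        self._in_dim = in_dim
        self._num_classes = num_classes

    @property
    def in_dims(self) -> int:
        # Input embedding size passed at construction.
        return self._in_dim

    @property
    def out_dims(self) -> int:
        # Assumption: the output size equals the number of classes.
        return self._num_classes


def check_encoder_width(encoder_out_dim: int, decoder: DimsSketch) -> None:
    # Hypothetical helper: fail early if embedding width and decoder input differ.
    if encoder_out_dim != decoder.in_dims:
        raise ValueError(f"encoder emits {encoder_out_dim} features, "
                         f"decoder expects {decoder.in_dims}")


dec = DimsSketch(in_dim=128, num_classes=10)
print(dec.in_dims, dec.out_dims)  # 128 10
check_encoder_width(128, dec)     # passes silently
```

Exposing sizes through properties with no setters keeps callers from changing them after the weights have been sized.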